Passed
Push — master ( 7861da...b12a19 )
by
unknown
05:56 queued 04:03
created

TestSearchDSLBools.test_query_must_should()   A

Complexity

Conditions 2

Size

Total Lines 9
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 5
nop 3
dl 0
loc 9
rs 10
c 0
b 0
f 0
1
import time
2
import pdb
3
import copy
4
import threading
5
import logging
6
from multiprocessing import Pool, Process
7
import pytest
8
import numpy as np
9
10
from milvus import DataType
11
from utils import *
12
13
# --- shared test configuration -------------------------------------------
dim = 128
segment_row_count = 5000
top_k_limit = 2048  # tests below expect any top_k above this to raise
collection_id = "search"
tag = "1970-01-01"
insert_interval_time = 1.5
nb = 6000
top_k = 10
nprobe = 1
epsilon = 0.001
field_name = default_float_vec_field_name
default_fields = gen_default_fields()
search_param = {"nprobe": 1}

# --- pre-generated data (helpers come from `utils` via the star import) ----
# `entities` / `raw_vectors` / `binary_entities` are reused by init_data and
# init_binary_data whenever nb == 6000.
entity = gen_entities(1, is_normal=True)
raw_vector, binary_entity = gen_binary_entities(1)
entities = gen_entities(nb, is_normal=True)
raw_vectors, binary_entities = gen_binary_entities(nb)
default_query, default_query_vecs = gen_query_vectors(field_name, entities, top_k, 1)


def init_data(connect, collection, nb=6000, partition_tags=None):
    '''
    Insert entities into ``collection``, flush it, and return them with their ids.

    For nb == 6000 the pre-generated module-level ``entities`` are reused;
    any other size gets a fresh ``gen_entities(nb, is_normal=True)`` batch.

    :param connect: client used for ``insert`` and ``flush``
    :param collection: target collection name
    :param nb: number of entities to insert
    :param partition_tags: optional value forwarded as ``partition_tag``
    :return: ``(inserted_entities, ids)``
    '''
    insert_entities = entities if nb == 6000 else gen_entities(nb, is_normal=True)
    if partition_tags is not None:
        ids = connect.insert(collection, insert_entities, partition_tag=partition_tags)
    else:
        ids = connect.insert(collection, insert_entities)
    connect.flush([collection])
    return insert_entities, ids


def init_binary_data(connect, collection, nb=6000, insert=True, partition_tags=None):
    '''
    Build binary entities and, when ``insert`` is True, insert and flush them.

    For nb == 6000 the pre-generated module-level ``raw_vectors`` /
    ``binary_entities`` are reused; otherwise ``gen_binary_entities(nb)``
    produces a fresh set.

    :param connect: client used for ``insert`` and ``flush``
    :param collection: target collection name
    :param nb: number of entities
    :param insert: only the literal ``True`` triggers insertion (checked with ``is``)
    :param partition_tags: optional value forwarded as ``partition_tag``
    :return: ``(raw_vectors, entities, ids)``; ``ids`` stays ``[]`` when nothing is inserted
    '''
    if nb == 6000:
        insert_raw_vectors, insert_entities = raw_vectors, binary_entities
    else:
        insert_raw_vectors, insert_entities = gen_binary_entities(nb)
    ids = []
    if insert is not True:
        return insert_raw_vectors, insert_entities, ids
    if partition_tags is None:
        ids = connect.insert(collection, insert_entities)
    else:
        ids = connect.insert(collection, insert_entities, partition_tag=partition_tags)
    connect.flush([collection])
    return insert_raw_vectors, insert_entities, ids


class TestSearchBase:
70
71
72
    """
73
    generate valid create_index params
74
    """
75
    @pytest.fixture(
76
        scope="function",
77
        params=gen_index()
0 ignored issues
show
Comprehensibility Best Practice introduced by
The variable gen_index does not seem to be defined.
Loading history...
78
    )
79
    def get_index(self, request, connect):
80
        if str(connect._cmd("mode")) == "CPU":
81
            if request.param["index_type"] in index_cpu_not_support():
0 ignored issues
show
Comprehensibility Best Practice introduced by
The variable index_cpu_not_support does not seem to be defined.
Loading history...
82
                pytest.skip("sq8h not support in CPU mode")
83
        return request.param
84
85
    @pytest.fixture(scope="function", params=gen_simple_index())
    def get_simple_index(self, request, connect):
        """Provide every index param from ``gen_simple_index()``.

        In CPU mode, index types listed by ``index_cpu_not_support()`` are
        skipped instead of returned.
        """
        index_param = request.param
        running_on_cpu = str(connect._cmd("mode")) == "CPU"
        if running_on_cpu and index_param["index_type"] in index_cpu_not_support():
            pytest.skip("sq8h not support in CPU mode")
        return index_param

    @pytest.fixture(scope="function", params=gen_simple_index())
    def get_jaccard_index(self, request, connect):
        """Provide simple index params whose type is in ``binary_support()``; skip all others."""
        index_param = request.param
        logging.getLogger().info(index_param)
        if index_param["index_type"] not in binary_support():
            pytest.skip("Skip index Temporary")
        return index_param

    @pytest.fixture(scope="function", params=gen_simple_index())
    def get_hamming_index(self, request, connect):
        """Provide simple index params whose type is in ``binary_support()``; skip all others."""
        index_param = request.param
        logging.getLogger().info(index_param)
        if index_param["index_type"] not in binary_support():
            pytest.skip("Skip index Temporary")
        return index_param

    @pytest.fixture(scope="function", params=gen_simple_index())
    def get_structure_index(self, request, connect):
        """Provide only the FLAT simple index param; every other index type is skipped."""
        index_param = request.param
        logging.getLogger().info(index_param)
        if index_param["index_type"] != "FLAT":
            pytest.skip("Skip index Temporary")
        return index_param

    """
129
    generate top-k params
130
    """
131
    @pytest.fixture(
132
        scope="function",
133
        params=[1, 10, 2049]
134
    )
135
    def get_top_k(self, request):
136
        yield request.param
137
138
    @pytest.fixture(scope="function", params=[1, 10, 1100])
    def get_nq(self, request):
        """Yield each number of query vectors (nq) to search with."""
        nq_value = request.param
        yield nq_value

    def test_search_flat(self, connect, collection, get_top_k, get_nq):
        '''
        target: test basic search with valid params while varying top-k and nq
        method: insert entities, search with query vectors generated from them, check the result
        expected: the first result has top_k hits, its best distance is within epsilon and
                  check_id_result matches ids[0]; top_k above top_k_limit raises an exception
        '''
        top_k, nq = get_top_k, get_nq
        entities, ids = init_data(connect, collection)
        query, _ = gen_query_vectors(field_name, entities, top_k, nq)
        if top_k > top_k_limit:
            with pytest.raises(Exception):
                connect.search(collection, query)
            return
        res = connect.search(collection, query)
        assert len(res[0]) == top_k
        assert res[0]._distances[0] <= epsilon
        assert check_id_result(res[0], ids[0])

    def test_search_field(self, connect, collection, get_top_k, get_nq):
        '''
        target: test search that requests specific output fields, varying top-k and nq
        method: insert entities, search with fields=["float_vector"], check the result
        expected: the first result has top_k hits, its best distance is within epsilon and
                  check_id_result matches ids[0]; top_k above top_k_limit raises an exception
        '''
        top_k, nq = get_top_k, get_nq
        entities, ids = init_data(connect, collection)
        query, _ = gen_query_vectors(field_name, entities, top_k, nq)
        if top_k > top_k_limit:
            with pytest.raises(Exception):
                connect.search(collection, query)
            return
        res = connect.search(collection, query, fields=["float_vector"])
        assert len(res[0]) == top_k
        assert res[0]._distances[0] <= epsilon
        assert check_id_result(res[0], ids[0])
        # TODO: this search requests the "float" field, but its result is not verified yet
        connect.search(collection, query, fields=["float"])

    @pytest.mark.level(2)
    def test_search_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq):
        '''
        target: test search after building each simple index type (IVF_PQ is skipped)
        method: insert entities, create the index, search with the index-specific search params
        expected: nq results; the first has at least top_k hits, its best distance is below
                  epsilon and check_id_result matches ids[0]; top_k above top_k_limit raises
        '''
        top_k, nq = get_top_k, get_nq
        index_type = get_simple_index["index_type"]
        if index_type == "IVF_PQ":
            pytest.skip("Skip PQ")
        entities, ids = init_data(connect, collection)
        connect.create_index(collection, field_name, get_simple_index)
        search_param = get_search_param(index_type)
        query, _ = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
        if top_k > top_k_limit:
            with pytest.raises(Exception):
                connect.search(collection, query)
            return
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) >= top_k
        assert res[0]._distances[0] < epsilon
        assert check_id_result(res[0], ids[0])

    def test_search_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq):
        '''
        target: test indexed search on a collection that also owns an empty partition
        method: create partition `tag`, insert entities without a partition tag, build the index,
                search the whole collection and then only partition `tag`
        expected: the whole-collection search finds the inserted vectors (best distance below
                  epsilon); searching partition `tag`, which holds no entities, returns nq
                  results whose first entry is empty; top_k above top_k_limit raises
        '''
        top_k = get_top_k
        nq = get_nq
        index_type = get_simple_index["index_type"]
        if index_type == "IVF_PQ":
            pytest.skip("Skip PQ")
        connect.create_partition(collection, tag)
        entities, ids = init_data(connect, collection)
        connect.create_index(collection, field_name, get_simple_index)
        search_param = get_search_param(index_type)
        # Previously called `gen_query_vectors_`, an undefined name (typo) that raised NameError.
        query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
        if top_k > top_k_limit:
            with pytest.raises(Exception):
                connect.search(collection, query)
        else:
            res = connect.search(collection, query)
            assert len(res) == nq
            assert len(res[0]) >= top_k
            assert res[0]._distances[0] < epsilon
            assert check_id_result(res[0], ids[0])
            # init_data was called without partition_tags, so partition `tag` holds nothing.
            res = connect.search(collection, query, partition_tags=[tag])
            assert len(res) == nq
            assert len(res[0]) == 0

    def test_search_index_partition_B(self, connect, collection, get_simple_index, get_top_k, get_nq):
        '''
        target: test indexed search restricted to the partition that holds the data
        method: insert entities into partition `tag`, build the index, search with
                partition_tags=[tag] and with [tag, "new_tag"]
        expected: for both tag lists, nq results; the first has at least top_k hits, best
                  distance below epsilon and check_id_result matches ids[0];
                  top_k above top_k_limit raises
        '''
        top_k, nq = get_top_k, get_nq
        index_type = get_simple_index["index_type"]
        if index_type == "IVF_PQ":
            pytest.skip("Skip PQ")
        connect.create_partition(collection, tag)
        entities, ids = init_data(connect, collection, partition_tags=tag)
        connect.create_index(collection, field_name, get_simple_index)
        search_param = get_search_param(index_type)
        query, _ = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
        for tags in ([tag], [tag, "new_tag"]):
            if top_k > top_k_limit:
                with pytest.raises(Exception):
                    connect.search(collection, query, partition_tags=tags)
                continue
            res = connect.search(collection, query, partition_tags=tags)
            assert len(res) == nq
            assert len(res[0]) >= top_k
            assert res[0]._distances[0] < epsilon
            assert check_id_result(res[0], ids[0])

    @pytest.mark.level(2)
    def test_search_index_partition_C(self, connect, collection, get_top_k, get_nq):
        '''
        target: test search restricted to a partition tag that does not exist in the collection
        method: insert entities without a partition tag (no index is built), then search with
                partition_tags=["new_tag"]
        expected: no error; nq results whose first entry is empty.
                  Only top_k above top_k_limit raises an exception.
        '''
        top_k = get_top_k
        nq = get_nq
        entities, ids = init_data(connect, collection)
        query, vecs = gen_query_vectors(field_name, entities, top_k, nq)
        if top_k > top_k_limit:
            with pytest.raises(Exception):
                connect.search(collection, query, partition_tags=["new_tag"])
        else:
            res = connect.search(collection, query, partition_tags=["new_tag"])
            assert len(res) == nq
            assert len(res[0]) == 0

    @pytest.mark.level(2)
    def test_search_index_partitions(self, connect, collection, get_simple_index, get_top_k):
        '''
        target: test indexed search over a collection split into two partitions
        method: insert entities into partition `tag` and 6001 fresh entities into `new_tag`,
                build the index, query with the `tag` entities over the whole collection and
                then over `new_tag` only
        expected: whole-collection search matches ids[0] for the first query and finds both
                  queries within epsilon; restricted to `new_tag`, neither query has a hit
                  within epsilon; top_k above top_k_limit raises
        '''
        top_k = get_top_k
        nq = 2
        new_tag = "new_tag"
        index_type = get_simple_index["index_type"]
        if index_type == "IVF_PQ":
            pytest.skip("Skip PQ")
        for partition in (tag, new_tag):
            connect.create_partition(collection, partition)
        entities, ids = init_data(connect, collection, partition_tags=tag)
        _, new_ids = init_data(connect, collection, nb=6001, partition_tags=new_tag)
        connect.create_index(collection, field_name, get_simple_index)
        search_param = get_search_param(index_type)
        query, _ = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
        if top_k > top_k_limit:
            with pytest.raises(Exception):
                connect.search(collection, query)
            return
        res = connect.search(collection, query)
        assert check_id_result(res[0], ids[0])
        assert not check_id_result(res[1], new_ids[0])
        assert res[0]._distances[0] < epsilon
        assert res[1]._distances[0] < epsilon
        res = connect.search(collection, query, partition_tags=[new_tag])
        assert res[0]._distances[0] > epsilon
        assert res[1]._distances[0] > epsilon

    # TODO: disabled via the leading underscore; presumably waiting on partition-tag
    # pattern matching ("(.*)tag", "new(.*)") — confirm before re-enabling.
    @pytest.mark.level(2)
    def _test_search_index_partitions_B(self, connect, collection, get_simple_index, get_top_k):
        '''
        target: test indexed search where partition tags are given as patterns
        method: insert entities into partitions "tag" and "new_tag", build the index, query with
                the "new_tag" entities using partition_tags ["(.*)tag"] and then ["new(.*)"]
        expected: the second query matches new_ids[0] within epsilon while the first has no
                  hit within epsilon; top_k above top_k_limit raises
        '''
        top_k = get_top_k
        nq = 2
        tag = "tag"
        new_tag = "new_tag"
        index_type = get_simple_index["index_type"]
        if index_type == "IVF_PQ":
            pytest.skip("Skip PQ")
        for partition in (tag, new_tag):
            connect.create_partition(collection, partition)
        _, ids = init_data(connect, collection, partition_tags=tag)
        new_entities, new_ids = init_data(connect, collection, nb=6001, partition_tags=new_tag)
        connect.create_index(collection, field_name, get_simple_index)
        search_param = get_search_param(index_type)
        query, _ = gen_query_vectors(field_name, new_entities, top_k, nq, search_params=search_param)
        if top_k > top_k_limit:
            with pytest.raises(Exception):
                connect.search(collection, query)
            return
        res = connect.search(collection, query, partition_tags=["(.*)tag"])
        assert not check_id_result(res[0], ids[0])
        assert check_id_result(res[1], new_ids[0])
        assert res[0]._distances[0] > epsilon
        assert res[1]._distances[0] < epsilon
        res = connect.search(collection, query, partition_tags=["new(.*)"])
        assert res[0]._distances[0] > epsilon
        assert res[1]._distances[0] < epsilon

    # ----------------------------------------------------------------------
    # tests for the IP metric
    # ----------------------------------------------------------------------
    @pytest.mark.level(2)
    def test_search_ip_flat(self, connect, collection, get_simple_index, get_top_k, get_nq):
        '''
        target: test IP-metric search without building an index, varying top-k and nq
        method: insert entities, search with IP query vectors generated from them
        expected: the first result has top_k hits, its best IP distance is at least
                  1 - gen_inaccuracy(...) and check_id_result matches ids[0];
                  top_k above top_k_limit raises
        '''
        # NOTE(review): get_simple_index is requested but never used, so this test is
        # repeated once per simple index param; kept to preserve the test signature.
        top_k, nq = get_top_k, get_nq
        entities, ids = init_data(connect, collection)
        query, _ = gen_query_vectors(field_name, entities, top_k, nq, metric_type="IP")
        if top_k > top_k_limit:
            with pytest.raises(Exception):
                connect.search(collection, query)
            return
        res = connect.search(collection, query)
        best = res[0]._distances[0]
        assert len(res[0]) == top_k
        assert best >= 1 - gen_inaccuracy(best)
        assert check_id_result(res[0], ids[0])

    def test_search_ip_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq):
        '''
        target: test IP-metric search after building each simple index type (IVF_PQ is skipped)
        method: insert entities, build the index with metric_type "IP", search with the
                index-specific search params
        expected: nq results; the first has at least top_k hits, check_id_result matches ids[0]
                  and the best IP distance is at least 1 - gen_inaccuracy(...);
                  top_k above top_k_limit raises
        '''
        top_k = get_top_k
        nq = get_nq
        index_type = get_simple_index["index_type"]
        if index_type == "IVF_PQ":
            pytest.skip("Skip PQ")
        entities, ids = init_data(connect, collection)
        # Work on a copy. The fixture returns the same dict object that lives in the
        # class-level gen_simple_index() params, so assigning into it in place leaked
        # metric_type="IP" into every later test that uses get_simple_index.
        index_param = dict(get_simple_index, metric_type="IP")
        connect.create_index(collection, field_name, index_param)
        search_param = get_search_param(index_type)
        query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type="IP", search_params=search_param)
        if top_k > top_k_limit:
            with pytest.raises(Exception):
                connect.search(collection, query)
        else:
            res = connect.search(collection, query)
            assert len(res) == nq
            assert len(res[0]) >= top_k
            assert check_id_result(res[0], ids[0])
            assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])

    @pytest.mark.level(2)
    def test_search_ip_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq):
        '''
        target: test IP-metric indexed search on a collection that also owns an empty partition
        method: create partition `tag`, insert entities without a partition tag, build an IP
                index, search the whole collection and then only partition `tag`
        expected: the whole-collection search finds the inserted vectors (best IP distance at
                  least 1 - gen_inaccuracy(...)); searching partition `tag`, which holds no
                  entities, returns nq results whose first entry is empty;
                  top_k above top_k_limit raises
        '''
        top_k = get_top_k
        nq = get_nq
        metric_type = "IP"
        index_type = get_simple_index["index_type"]
        if index_type == "IVF_PQ":
            pytest.skip("Skip PQ")
        connect.create_partition(collection, tag)
        entities, ids = init_data(connect, collection)
        # Copy instead of mutating: the fixture's dict is shared across all tests through
        # the class-level gen_simple_index() params.
        index_param = dict(get_simple_index, metric_type=metric_type)
        connect.create_index(collection, field_name, index_param)
        search_param = get_search_param(index_type)
        query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type=metric_type, search_params=search_param)
        if top_k > top_k_limit:
            with pytest.raises(Exception):
                connect.search(collection, query)
        else:
            res = connect.search(collection, query)
            assert len(res) == nq
            assert len(res[0]) >= top_k
            assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])
            assert check_id_result(res[0], ids[0])
            # init_data was called without partition_tags, so partition `tag` holds nothing.
            res = connect.search(collection, query, partition_tags=[tag])
            assert len(res) == nq
            assert len(res[0]) == 0

    @pytest.mark.level(2)
    def test_search_ip_index_partitions(self, connect, collection, get_simple_index, get_top_k):
        '''
        target: test IP-metric indexed search over a collection split into two partitions
        method: insert entities into partition `tag` and 6001 fresh entities into `new_tag`,
                build an IP index, query with the `tag` entities over the whole collection and
                then over `new_tag` only
        expected: whole-collection search matches ids[0] for the first query and both best IP
                  distances are at least 1 - gen_inaccuracy(...); restricted to `new_tag`, the
                  first query's best IP distance falls below that bound;
                  top_k above top_k_limit raises
        '''
        top_k = get_top_k
        nq = 2
        metric_type = "IP"
        new_tag = "new_tag"
        index_type = get_simple_index["index_type"]
        if index_type == "IVF_PQ":
            pytest.skip("Skip PQ")
        connect.create_partition(collection, tag)
        connect.create_partition(collection, new_tag)
        entities, ids = init_data(connect, collection, partition_tags=tag)
        new_entities, new_ids = init_data(connect, collection, nb=6001, partition_tags=new_tag)
        # Copy instead of mutating: the fixture's dict is shared across all tests through
        # the class-level gen_simple_index() params.
        index_param = dict(get_simple_index, metric_type=metric_type)
        connect.create_index(collection, field_name, index_param)
        search_param = get_search_param(index_type)
        query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type=metric_type, search_params=search_param)
        if top_k > top_k_limit:
            with pytest.raises(Exception):
                connect.search(collection, query)
        else:
            res = connect.search(collection, query)
            assert check_id_result(res[0], ids[0])
            assert not check_id_result(res[1], new_ids[0])
            assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])
            assert res[1]._distances[0] >= 1 - gen_inaccuracy(res[1]._distances[0])
            res = connect.search(collection, query, partition_tags=[new_tag])
            assert res[0]._distances[0] < 1 - gen_inaccuracy(res[0]._distances[0])
            # TODO: decide the expected bound for the second query when searching only new_tag
            # assert res[1]._distances[0] >= 1 - gen_inaccuracy(res[1]._distances[0])

    @pytest.mark.level(2)
    def test_search_without_connect(self, dis_connect, collection):
        '''
        target: test search through a disconnected client
        method: call search on the dis_connect instance with the default query
        expected: an exception is raised
        '''
        with pytest.raises(Exception):
            dis_connect.search(collection, default_query)

    def test_search_collection_name_not_existed(self, connect):
        '''
        target: test search on a collection that does not exist
        method: search a freshly generated random collection name that was never created
        expected: an exception is raised
        '''
        missing_collection = gen_unique_str(collection_id)
        with pytest.raises(Exception):
            connect.search(missing_collection, default_query)

    def test_search_distance_l2(self, connect, collection):
        '''
        target: search collection, and check the result: distance
        method: insert nq entities, search with a random vector and compare the returned
                distance with the Euclidean distances computed locally via l2()
        expected: the returned distance matches the smaller computed value within gen_inaccuracy
        '''
        nq = 2
        search_param = {"nprobe" : 1}
        entities, ids = init_data(connect, collection, nb=nq)
        query, vecs = gen_query_vectors(field_name, entities, top_k, nq, rand_vector=True, search_params=search_param)
        _, inside_vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
        expected = min(l2(vecs[0], inside_vecs[0]), l2(vecs[0], inside_vecs[1]))
        res = connect.search(collection, query)
        nearest = res[0]._distances[0]
        # The returned value is square-rooted before comparing, presumably because the
        # server reports squared L2 distances — confirm against the server docs.
        assert abs(np.sqrt(nearest) - expected) <= gen_inaccuracy(nearest)

    # TODO: distance problem (still disabled via the leading underscore)
    def _test_search_distance_l2_after_index(self, connect, collection, get_simple_index):
        '''
        target: search collection after building an index, and check the result: distance
        method: compare the returned distance with the minimum Euclidean distance (l2) from the
                random query vector to every inserted vector
        expected: the returned distance equals the computed minimum within gen_inaccuracy
        '''
        index_type = get_simple_index["index_type"]
        nq = 2
        entities, ids = init_data(connect, collection)
        connect.create_index(collection, field_name, get_simple_index)
        search_param = get_search_param(index_type)
        query, vecs = gen_query_vectors(field_name, entities, top_k, nq, rand_vector=True, search_params=search_param)
        inside_vecs = entities[-1]["values"]
        # Brute-force nearest neighbour over the vectors actually inserted. The old loop
        # started from min_distance = 1.0, so any true minimum above 1.0 was reported as
        # 1.0. It also iterated range(nb) using the module-level nb rather than len(inside_vecs).
        min_distance = min(l2(vecs[0], vec) for vec in inside_vecs)
        res = connect.search(collection, query)
        assert abs(np.sqrt(res[0]._distances[0]) - min_distance) <= gen_inaccuracy(res[0]._distances[0])

    def test_search_distance_ip(self, connect, collection):
        '''
        target: search collection, and check the result: distance
        method: insert two entities, search with a random vector under the IP metric, and
                compare the top distance with the larger of the two locally computed inner products
        expected: the return distance equals to the computed value
        '''
        nq = 2
        metric_type = "IP"
        search_param = {"nprobe" : 1}
        entities, ids = init_data(connect, collection, nb=nq)
        query, vecs = gen_query_vectors(field_name, entities, top_k, nq, rand_vector=True, metric_type=metric_type, search_params=search_param)
        inside_query, inside_vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
        # Under IP a larger score means a closer match, so the top hit should be the maximum.
        first_score = ip(vecs[0], inside_vecs[0])
        second_score = ip(vecs[0], inside_vecs[1])
        expected_score = max(first_score, second_score)
        res = connect.search(collection, query)
        top_score = res[0]._distances[0]
        assert abs(top_score - expected_score) <= gen_inaccuracy(top_score)
548
549
    # TODO: distance problem
550 View Code Duplication
    def _test_search_distance_ip_after_index(self, connect, collection, get_simple_index):
        '''
        target: search collection, and check the result: distance
        method: build an IP index, then compare the return distance value with the maximum
                inner product computed locally between the query vector and every inserted vector
        expected: the return distance equals to the computed value
        '''
        index_type = get_simple_index["index_type"]
        nq = 2
        metric_type = "IP"
        entities, ids = init_data(connect, collection)
        # Work on a copy so the fixture's dict is not modified; otherwise the IP metric could
        # leak into other tests that receive the same parameter object.
        index_param = copy.deepcopy(get_simple_index)
        index_param["metric_type"] = metric_type
        connect.create_index(collection, field_name, index_param)
        search_param = get_search_param(index_type)
        query, vecs = gen_query_vectors(field_name, entities, top_k, nq, rand_vector=True, metric_type=metric_type, search_params=search_param)
        # NOTE(review): assumes the float-vector field is the last entry of `entities` -- confirm.
        inside_vecs = entities[-1]["values"]
        # Inner products can be negative, so seeding a running max with 0 would be wrong;
        # take the true maximum over every inserted vector instead.
        max_distance = max(ip(vecs[0], vec) for vec in inside_vecs)
        res = connect.search(collection, query)
        assert abs(res[0]._distances[0] - max_distance) <= gen_inaccuracy(res[0]._distances[0])
572
573
    # TODO:
574
    def _test_search_distance_jaccard_flat_index(self, connect, binary_collection):
        '''
        target: search binary_collection, and check the result: distance
        method: compare the return distance value with the Jaccard distance computed locally
                against each of the two inserted vectors
        expected: the return distance equals to the smaller computed value
        '''
        int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2)
        query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False)
        distance_0 = jaccard(query_int_vectors[0], int_vectors[0])
        distance_1 = jaccard(query_int_vectors[0], int_vectors[1])
        res = connect.search(binary_collection, query_entities)
        # Jaccard is a distance (smaller is closer), so the top hit should match the minimum.
        assert abs(res[0]._distances[0] - min(distance_0, distance_1)) <= epsilon
588
589
    def _test_search_distance_hamming_flat_index(self, connect, binary_collection):
        '''
        target: search binary_collection, and check the result: distance
        method: compare the return distance value with the Hamming distance computed locally
                against each of the two inserted vectors
        expected: the return distance equals to the smaller computed value
        '''
        int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2)
        query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False)
        distance_0 = hamming(query_int_vectors[0], int_vectors[0])
        distance_1 = hamming(query_int_vectors[0], int_vectors[1])
        res = connect.search(binary_collection, query_entities)
        # float() accepts both numpy scalars and plain Python numbers, unlike .astype(float).
        # Read the distance through `_distances`, like the sibling jaccard case does.
        assert abs(res[0]._distances[0] - float(min(distance_0, distance_1))) <= epsilon
603
604 View Code Duplication
    def _test_search_distance_substructure_flat_index(self, connect, binary_collection):
        '''
        target: search binary_collection with the SUBSTRUCTURE metric on a FLAT index
        method: insert two binary vectors, then search with a vector produced by
                init_binary_data(..., insert=False) that is not inserted
        expected: the test expects no hit for the query
        '''
        # Use the module-level helper, the same one the jaccard/hamming cases call.
        int_vectors, vectors, ids = init_binary_data(connect, binary_collection, nb=2)
        index_type = "FLAT"
        index_param = {
            "nlist": 16384,
            "metric_type": "SUBSTRUCTURE"
        }
        connect.create_index(binary_collection, binary_field_name, index_param)
        logging.getLogger().info(connect.get_collection_info(binary_collection))
        logging.getLogger().info(connect.get_index_info(binary_collection))
        query_int_vectors, query_vecs, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False)
        search_param = get_search_param(index_type)
        # TODO(review): still uses the positional (collection, top_k, vectors, params=...) search
        # signature returning (status, result), unlike connect.search(collection, query) used by
        # the enabled tests; port before re-enabling.
        status, result = connect.search(binary_collection, top_k, query_vecs, params=search_param)
        logging.getLogger().info(status)
        logging.getLogger().info(result)
        assert len(result[0]) == 0
629
630 View Code Duplication
    def _test_search_distance_substructure_flat_index_B(self, connect, binary_collection):
        '''
        target: search binary_collection, and check the result: distance
        method: build query vectors with gen_binary_sub_vectors from the two inserted vectors
                and search with the SUBSTRUCTURE metric
        expected: each query returns exactly its source vector with distance ~0
        '''
        top_k = 3
        # Use the module-level helper, the same one the jaccard/hamming cases call.
        int_vectors, vectors, ids = init_binary_data(connect, binary_collection, nb=2)
        index_type = "FLAT"
        index_param = {
            "nlist": 16384,
            "metric_type": "SUBSTRUCTURE"
        }
        connect.create_index(binary_collection, binary_field_name, index_param)
        logging.getLogger().info(connect.get_collection_info(binary_collection))
        logging.getLogger().info(connect.get_index_info(binary_collection))
        query_int_vectors, query_vecs = gen_binary_sub_vectors(int_vectors, 2)
        search_param = get_search_param(index_type)
        # TODO(review): still uses the positional (collection, top_k, vectors, params=...) search
        # signature returning (status, result); port before re-enabling.
        status, result = connect.search(binary_collection, top_k, query_vecs, params=search_param)
        logging.getLogger().info(status)
        logging.getLogger().info(result)
        assert len(result[0]) == 1
        assert len(result[1]) == 1
        assert result[0][0].distance <= epsilon
        assert result[0][0].id == ids[0]
        assert result[1][0].distance <= epsilon
        assert result[1][0].id == ids[1]
659
660 View Code Duplication
    def _test_search_distance_superstructure_flat_index(self, connect, binary_collection):
        '''
        target: search binary_collection with the SUPERSTRUCTURE metric on a FLAT index
        method: insert two binary vectors, then search with a vector produced by
                init_binary_data(..., insert=False) that is not inserted
        expected: the test expects no hit for the query
        '''
        # Use the module-level helper, the same one the jaccard/hamming cases call.
        int_vectors, vectors, ids = init_binary_data(connect, binary_collection, nb=2)
        index_type = "FLAT"
        # This is the superstructure case, so the index must use SUPERSTRUCTURE
        # (it previously used SUBSTRUCTURE by copy-paste).
        index_param = {
            "nlist": 16384,
            "metric_type": "SUPERSTRUCTURE"
        }
        connect.create_index(binary_collection, binary_field_name, index_param)
        logging.getLogger().info(connect.get_collection_info(binary_collection))
        logging.getLogger().info(connect.get_index_info(binary_collection))
        query_int_vectors, query_vecs, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False)
        search_param = get_search_param(index_type)
        # TODO(review): still uses the positional (collection, top_k, vectors, params=...) search
        # signature returning (status, result); port before re-enabling.
        status, result = connect.search(binary_collection, top_k, query_vecs, params=search_param)
        logging.getLogger().info(status)
        logging.getLogger().info(result)
        assert len(result[0]) == 0
685
686 View Code Duplication
    def _test_search_distance_superstructure_flat_index_B(self, connect, binary_collection):
        '''
        target: search binary_collection, and check the result: distance
        method: build query vectors with gen_binary_super_vectors from the two inserted vectors
                and search with the SUPERSTRUCTURE metric
        expected: each query returns two hits whose top distance is ~0
        '''
        top_k = 3
        # Use the module-level helper, the same one the jaccard/hamming cases call.
        int_vectors, vectors, ids = init_binary_data(connect, binary_collection, nb=2)
        index_type = "FLAT"
        # Super-vector queries only match under SUPERSTRUCTURE; SUBSTRUCTURE (the previous
        # copy-paste value) contradicts the expected hit counts below.
        index_param = {
            "nlist": 16384,
            "metric_type": "SUPERSTRUCTURE"
        }
        connect.create_index(binary_collection, binary_field_name, index_param)
        logging.getLogger().info(connect.get_collection_info(binary_collection))
        logging.getLogger().info(connect.get_index_info(binary_collection))
        query_int_vectors, query_vecs = gen_binary_super_vectors(int_vectors, 2)
        search_param = get_search_param(index_type)
        # TODO(review): still uses the positional (collection, top_k, vectors, params=...) search
        # signature returning (status, result); port before re-enabling.
        status, result = connect.search(binary_collection, top_k, query_vecs, params=search_param)
        logging.getLogger().info(status)
        logging.getLogger().info(result)
        assert len(result[0]) == 2
        assert len(result[1]) == 2
        assert result[0][0].id in ids
        assert result[0][0].distance <= epsilon
        assert result[1][0].id in ids
        assert result[1][0].distance <= epsilon
715
716 View Code Duplication
    def _test_search_distance_tanimoto_flat_index(self, connect, binary_collection):
        '''
        target: search binary_collection, and check the result: distance
        method: compare the return distance value with the Tanimoto distance computed locally
                against each of the two inserted vectors
        expected: the return distance equals to the smaller computed value
        '''
        # Use the module-level helper, the same one the jaccard/hamming cases call.
        int_vectors, vectors, ids = init_binary_data(connect, binary_collection, nb=2)
        index_type = "FLAT"
        index_param = {
            "nlist": 16384,
            "metric_type": "TANIMOTO"
        }
        connect.create_index(binary_collection, binary_field_name, index_param)
        logging.getLogger().info(connect.get_collection_info(binary_collection))
        logging.getLogger().info(connect.get_index_info(binary_collection))
        query_int_vectors, query_vecs, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False)
        distance_0 = tanimoto(query_int_vectors[0], int_vectors[0])
        distance_1 = tanimoto(query_int_vectors[0], int_vectors[1])
        search_param = get_search_param(index_type)
        # TODO(review): still uses the positional (collection, top_k, vectors, params=...) search
        # signature returning (status, result); port before re-enabling.
        status, result = connect.search(binary_collection, top_k, query_vecs, params=search_param)
        logging.getLogger().info(status)
        logging.getLogger().info(result)
        assert abs(result[0][0].distance - min(distance_0, distance_1)) <= epsilon
741
742
    @pytest.mark.timeout(30)
    def test_search_concurrent_multithreads(self, connect, args):
        '''
        target: test concurrent search with multiple threads
        method: search with 4 threads, each thread using its own connection from get_milvus
        expected: every search succeeds and its top hit is an inserted entity with distance ~0
        '''
        threads_num = 4
        threads = []
        # Exceptions (including assertion failures) raised inside a worker thread never reach
        # pytest on their own; collect them here and assert on them after join().
        errors = []
        collection = gen_unique_str(collection_id)
        # create collection
        milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
        milvus.create_collection(collection, default_fields)
        entities, ids = init_data(milvus, collection)

        def search(client):
            # Search through the connection handed to this thread, not the shared fixture.
            try:
                res = client.search(collection, default_query)
                assert len(res) == 1
                assert res[0]._entities[0].id in ids
                assert res[0]._distances[0] < epsilon
            except Exception as exc:
                errors.append(exc)

        for _ in range(threads_num):
            client = get_milvus(args["ip"], args["port"], handler=args["handler"])
            t = threading.Thread(target=search, args=(client, ))
            threads.append(t)
            t.start()
            time.sleep(0.2)
        for t in threads:
            t.join()
        assert not errors, errors
772
773
    @pytest.mark.timeout(30)
    def test_search_concurrent_multithreads_single_connection(self, connect, args):
        '''
        target: test concurrent search with multiple threads sharing one connection
        method: search with 4 threads, all using the single connection that created the collection
        expected: every search succeeds and its top hit is an inserted entity with distance ~0
        '''
        threads_num = 4
        threads = []
        # Exceptions (including assertion failures) raised inside a worker thread never reach
        # pytest on their own; collect them here and assert on them after join().
        errors = []
        collection = gen_unique_str(collection_id)
        # create collection
        milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
        milvus.create_collection(collection, default_fields)
        entities, ids = init_data(milvus, collection)

        def search(client):
            # Use the shared connection passed in, which is the point of this test.
            try:
                res = client.search(collection, default_query)
                assert len(res) == 1
                assert res[0]._entities[0].id in ids
                assert res[0]._distances[0] < epsilon
            except Exception as exc:
                errors.append(exc)

        for _ in range(threads_num):
            t = threading.Thread(target=search, args=(milvus, ))
            threads.append(t)
            t.start()
            time.sleep(0.2)
        for t in threads:
            t.join()
        assert not errors, errors
802
803
    def test_search_multi_collections(self, connect, args):
        '''
        target: test search multi collections of L2
        method: create 10 collections, insert the default entities into each, and search each
                one with nq query vectors taken from the inserted data
        expected: search status ok, the length of result
        '''
        num = 10
        top_k = 10
        nq = 20
        for collection_index in range(num):
            collection = gen_unique_str(collection_id + str(collection_index))
            connect.create_collection(collection, default_fields)
            entities, ids = init_data(connect, collection)
            assert len(ids) == nb
            query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
            res = connect.search(collection, query)
            assert len(res) == nq
            # Each query vector is an inserted vector: its own id comes first at ~0 distance,
            # and the runner-up must be strictly farther away.
            for query_index in range(nq):
                hits = res[query_index]
                assert check_id_result(hits, ids[query_index])
                assert hits._distances[0] < epsilon
                assert hits._distances[1] > epsilon
824
825
826
class TestSearchDSL(object):
    """
    ******************************************************************
    #  The following cases are used to build invalid query expr
    ******************************************************************
    """

    # TODO: assert exception
    def test_query_no_must(self, connect, collection):
        '''
        method: build query without must expr
        expected: error raised
        '''
        # entities, ids = init_data(connect, collection)
        query = update_query_expr(default_query, keep_old=False)
        with pytest.raises(Exception):
            connect.search(collection, query)

    # TODO:
    def test_query_no_vector_term_only(self, connect, collection):
        '''
        method: build query with only a term expr under must and no vector expr
        expected: error raised
        '''
        # entities, ids = init_data(connect, collection)
        # Call the generator: passing the function object itself would make the search fail
        # for an unrelated reason and hide the behavior under test.
        expr = {
            "must": [gen_default_term_expr()]
        }
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception):
            connect.search(collection, query)

    def test_query_wrong_format(self, connect, collection):
        '''
        method: build query without must expr, with wrong expr name
        expected: error raised
        '''
        # entities, ids = init_data(connect, collection)
        expr = {
            "must1": [gen_default_term_expr()]
        }
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception):
            connect.search(collection, query)

    def test_query_empty(self, connect, collection):
        '''
        method: search with empty query
        expected: error raised
        '''
        query = {}
        with pytest.raises(Exception):
            connect.search(collection, query)

    def test_query_with_wrong_format_term(self, connect, collection):
        '''
        method: build query with wrong term expr
        expected: error raised
        '''
        # Build a real term expr and then corrupt its payload. Assigning an item on the bare
        # function object would raise TypeError before the search is even attempted.
        expr = gen_default_term_expr()
        expr["term"] = 1
        query = update_query_expr(default_query, expr=expr)
        with pytest.raises(Exception):
            connect.search(collection, query)

    # ******************************************************************
    #  The following cases are used to build valid query expr
    # ******************************************************************
    def test_query_term_value_not_in(self, connect, collection):
        '''
        method: build query with vector and term expr, with no term can be filtered
        expected: filter pass
        '''
        entities, ids = init_data(connect, collection)
        expr = gen_default_term_expr(values=[100000])
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        # TODO:

    def test_query_term_value_all_in(self, connect, collection):
        '''
        method: build query with vector and term expr, with all term can be filtered
        expected: filter pass
        '''
        entities, ids = init_data(connect, collection)
        expr = gen_default_term_expr(values=[1])
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        # TODO:

    def test_query_term_values_not_in(self, connect, collection):
        '''
        method: build query with vector and term expr, with no term can be filtered
        expected: filter pass
        '''
        entities, ids = init_data(connect, collection)
        expr = gen_default_term_expr(values=list(range(100000, 100010)))
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        # TODO:

    def test_query_term_values_all_in(self, connect, collection):
        '''
        method: build query with vector and term expr, with all term can be filtered
        expected: filter pass
        '''
        entities, ids = init_data(connect, collection)
        expr = gen_default_term_expr()
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        # TODO:

    def test_query_term_values_parts_in(self, connect, collection):
        '''
        method: build query with vector and term expr, with parts of term can be filtered
        expected: filter pass
        '''
        entities, ids = init_data(connect, collection)
        # Integer division: range() rejects the floats that `nb/2` produces on Python 3.
        expr = gen_default_term_expr(values=list(range(nb // 2, nb + nb // 2)))
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        # TODO:

    def test_query_term_values_repeat(self, connect, collection):
        '''
        method: build query with vector and term expr, with the same values
        expected: filter pass
        '''
        entities, ids = init_data(connect, collection)
        # nb - 1 copies of the same value (matches the original range(1, nb) length).
        expr = gen_default_term_expr(values=[1] * (nb - 1))
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        # TODO:
963
964
965
class TestSearchDSLBools(object):
    """
    ******************************************************************
    #  The following cases are used to build invalid query expr
    ******************************************************************
    """

    def test_query_no_bool(self, connect, collection):
        '''
        method: build query without bool expr
        expected: error raised
        '''
        # The top-level key is deliberately "bool1" instead of "bool".
        # The search must receive this expr: previously it referenced an
        # undefined name `query`, so the test "passed" on a NameError rather
        # than on the server rejecting the query.
        expr = {"bool1": {}}
        query = expr
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    def test_query_should_only_term(self, connect, collection):
        '''
        method: build query without must, with should.term instead
        expected: error raised
        '''
        # Call the helper to get a real term expr; embedding the function
        # object itself would make the search fail for an unrelated reason.
        expr = {"should": gen_default_term_expr()}
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    def test_query_should_only_vector(self, connect, collection):
        '''
        method: build query without must, with should.vector instead
        expected: error raised
        '''
        expr = {"should": default_query["bool"]["must"]}
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    def test_query_must_not_only_term(self, connect, collection):
        '''
        method: build query without must, with must_not.term instead
        expected: error raised
        '''
        # Call the helper; see test_query_should_only_term.
        expr = {"must_not": gen_default_term_expr()}
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    def test_query_must_not_vector(self, connect, collection):
        '''
        method: build query without must, with must_not.vector instead
        expected: error raised
        '''
        expr = {"must_not": default_query["bool"]["must"]}
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    def test_query_must_should(self, connect, collection):
        '''
        method: build query must, and with should.term
        expected: error raised
        '''
        # Call the helper; see test_query_should_only_term.
        expr = {"should": gen_default_term_expr()}
        query = update_query_expr(default_query, keep_old=True, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)


# ******************************************************************
#  The following cases are used to test `search` function
#  with invalid collection_name, or invalid query expr
# ******************************************************************
1038
1039
class TestSearchInvalid(object):
    """
    Test search with invalid collection names, tags, field names,
    query (top_k) and search params.
    """
    @pytest.fixture(
        scope="function",
        params=gen_invalid_strs()
    )
    def get_collection_name(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_invalid_strs()
    )
    def get_invalid_tag(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_invalid_strs()
    )
    def get_invalid_field(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_simple_index(self, request, connect):
        # Skip index types listed by index_cpu_not_support() when the server
        # reports CPU mode.
        if str(connect._cmd("mode")) == "CPU":
            if request.param["index_type"] in index_cpu_not_support():
                pytest.skip("sq8h not support in CPU mode")
        return request.param

    @pytest.mark.level(2)
    def test_search_with_invalid_collection(self, connect, get_collection_name):
        '''
        target: test search function, with an invalid collection name
        method: search with each generated invalid name
        expected: raise an error
        '''
        collection_name = get_collection_name
        with pytest.raises(Exception) as e:
            res = connect.search(collection_name, default_query)

    @pytest.mark.level(1)
    def test_search_with_invalid_tag(self, connect, collection):
        '''
        target: test search function, with an invalid partition tag
        method: search with a blank tag
        expected: raise an error
        '''
        tag = " "
        with pytest.raises(Exception) as e:
            res = connect.search(collection, default_query, partition_tags=tag)

    @pytest.mark.level(2)
    def test_search_with_invalid_field_name(self, connect, collection, get_invalid_field):
        '''
        target: test search function, with an invalid output field name
        method: search with each generated invalid field name
        expected: raise an error
        '''
        fields = [get_invalid_field]
        with pytest.raises(Exception) as e:
            res = connect.search(collection, default_query, fields=fields)

    @pytest.mark.level(1)
    def test_search_with_not_existed_field_name(self, connect, collection):
        '''
        target: test search function, with a field name not in the collection
        method: search with a freshly generated unique field name
        expected: raise an error
        '''
        fields = [gen_unique_str("field_name")]
        with pytest.raises(Exception) as e:
            res = connect.search(collection, default_query, fields=fields)

    # ---- Test search collection with invalid query ----
    @pytest.fixture(
        scope="function",
        params=gen_invalid_ints()
    )
    def get_top_k(self, request):
        yield request.param

    @pytest.mark.level(1)
    def test_search_with_invalid_top_k(self, connect, collection, get_top_k):
        '''
        target: test search function, with the wrong top_k
        method: search with top_k
        expected: raise an error, and the connection is normal
        '''
        top_k = get_top_k
        # Work on a deep copy: assigning into the module-level default_query
        # would leak the invalid top_k into every later test that uses it.
        query = copy.deepcopy(default_query)
        query["bool"]["must"][0]["vector"][field_name]["topk"] = top_k
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    # ---- Test search collection with invalid search params ----
    @pytest.fixture(
        scope="function",
        params=gen_invaild_search_params()
    )
    def get_search_params(self, request):
        yield request.param

    # TODO: This case can all pass, but it's too slow
    @pytest.mark.level(2)
    def _test_search_with_invalid_params(self, connect, collection, get_simple_index, get_search_params):
        '''
        target: test search function, with the wrong nprobe
        method: search with nprobe
        expected: raise an error, and the connection is normal
        '''
        search_params = get_search_params
        index_type = get_simple_index["index_type"]
        # Decide on skipping before inserting data and building the index so
        # mismatched (index, params) combinations do not pay that cost.
        if search_params["index_type"] != index_type:
            pytest.skip("Skip case")
        entities, ids = init_data(connect, collection)
        connect.create_index(collection, field_name, get_simple_index)
        query, vecs = gen_query_vectors(field_name, entities, top_k, 1, search_params=search_params["search_params"])
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    def test_search_with_empty_params(self, connect, collection, args, get_simple_index):
        '''
        target: test search function, with empty search params
        method: search with params
        expected: raise an error, and the connection is normal
        '''
        index_type = get_simple_index["index_type"]
        if args["handler"] == "HTTP":
            pytest.skip("skip in http mode")
        if index_type == "FLAT":
            pytest.skip("skip in FLAT index")
        entities, ids = init_data(connect, collection)
        connect.create_index(collection, field_name, get_simple_index)
        query, vecs = gen_query_vectors(field_name, entities, top_k, 1, search_params={})
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)
1165
1166
1167
def check_id_result(result, id):
    """Return True if ``id`` appears among the leading hits of ``result``.

    Only the first 5 entries are inspected; a shorter ``result`` is
    inspected in full. Each entry is read through its ``id`` attribute.
    """
    top_n = 5
    hit_ids = [hit.id for hit in result]
    # Slicing past the end of a list returns the whole list, so results
    # shorter than top_n need no separate branch.
    return id in hit_ids[:top_n]
1174