Passed: Push — master (b028d7...98336c), by unknown, created 01:59

test_search (Rating: F)

Complexity

Total Complexity 201

Size/Duplication

Total Lines 1680
Duplicated Lines 19.11 %

Importance

Changes 0
Metric Value
eloc 1051
dl 321
loc 1680
rs 1.3959
c 0
b 0
f 0
wmc 201

96 Methods

Rating   Name   Duplication   Size   Complexity  
A TestSearchBase.get_simple_index() 0 9 3
A TestSearchBase.get_structure_index() 0 10 2
A TestSearchBase.get_index() 0 9 3
A TestSearchBase.get_jaccard_index() 0 10 2
A TestSearchBase.get_hamming_index() 0 10 2
A TestSearchBase.test_search_flat_top_k() 18 18 3
A TestSearchBase.get_nq() 0 6 1
A TestSearchBase.test_search_field() 0 21 4
A TestSearchBase.get_top_k() 0 6 1
A TestSearchBase.test_search_flat() 18 18 3
A TestSearchInvalid.test_search_with_invalid_collection() 0 5 2
A TestSearchDSLBools.test_query_no_bool() 0 11 2
A TestSearchBase.test_search_concurrent_multithreads() 0 33 3
A TestSearchDSLBools.test_query_should_only_vector() 0 9 2
A TestSearchBase.test_search_distance_substructure_flat_index() 15 15 1
A TestSearchDSL.test_query_multi_range_different_fields() 0 14 1
A TestSearchDSLBools.test_query_must_not_only_term() 0 9 2
A TestSearchDSL.test_query_range_one_field_not_existed() 0 12 2
A TestSearchDSL.test_query_range_invalid_ranges() 0 13 1
A TestSearchInvalid.test_search_with_not_existed_field_name() 0 5 2
A TestSearchInvalid.get_top_k() 0 6 1
A TestSearchDSLBools.test_query_must_not_vector() 0 9 2
A TestSearchBase.test_search_distance_hamming_flat_index() 15 15 1
A TestSearchBase.test_search_concurrent_multithreads_single_connection() 0 32 3
A TestSearchInvalid.get_simple_index() 0 9 3
A TestSearchBase.test_search_after_delete() 0 44 4
A TestSearchBase.test_search_without_connect() 0 9 2
A TestSearchDSL.get_invalid_term() 0 6 1
A TestSearchBase.test_search_distance_ip_after_index() 26 26 3
A TestSearchBase.test_search_distance_ip() 0 18 1
A TestSearchDSL.test_query_wrong_format() 0 12 2
A TestSearchDSLBools.test_query_must_should() 0 9 2
A TestSearchInvalid.test_search_with_empty_params() 17 17 4
A TestSearchDSL.test_query_term_one_field_not_existed() 0 13 2
A TestSearchDSL.test_query_vector_only() 0 5 1
A TestSearchDSL.test_query_term_field_named_term() 0 23 1
A TestSearchDSL.test_query_no_vector_range_only() 0 12 2
A TestSearchBase.test_search_ip_index_partitions() 0 33 4
A TestSearchInvalid.get_invalid_tag() 0 6 1
A TestSearchDSL.test_query_multi_vectors_same_field() 0 14 2
A TestSearchBase.test_search_multi_collections() 0 22 3
A TestSearchBase.test_search_ip_flat() 0 19 3
A TestSearchDSL.test_query_term_wrong_format() 0 12 2
A TestSearchInvalid._test_search_with_invalid_tag() 0 5 2
A TestSearchInvalid.test_search_with_invalid_top_k() 0 11 2
A TestSearchDSL.test_query_range_wrong_format() 0 12 2
A TestSearchDSL.test_query_single_term_range_has_common() 14 14 1
A TestSearchBase.test_search_ip_index_partition() 0 31 4
A TestSearchInvalid.get_search_params() 0 6 1
A TestSearchDSL.test_query_term_value_empty() 0 10 1
A TestSearchBase.test_search_distance_tanimoto_flat_index() 15 15 1
A TestSearchBase.test_search_distance_substructure_flat_index_B() 0 16 1
A TestSearchDSL.test_query_term_value_all_in() 0 12 1
B TestSearchBase.test_search_index_partition_B() 0 28 5
A TestSearchBase.test_search_index_partitions() 0 32 4
A TestSearchDSL.test_query_range_string_ranges() 0 13 2
A TestSearchDSL.test_query_term_key_error() 0 11 2
A TestSearchBase.test_search_distance_superstructure_flat_index_B() 0 18 1
A TestSearchDSL.test_query_range_key_error() 0 10 2
A TestSearchInvalid.test_search_with_invalid_field_name() 0 5 2
A TestSearchDSL.get_valid_ranges() 0 6 1
A TestSearchDSL.get_invalid_range() 0 6 1
A TestSearchBase.test_search_distance_l2_after_index() 23 23 3
A TestSearchDSL.test_query_term_values_all_in() 0 11 1
A TestSearchBase.test_search_distance_superstructure_flat_index() 15 15 1
A TestSearchDSL.test_query_term_values_not_in() 13 13 1
A TestSearchDSL.test_query_no_vector_term_only() 0 12 2
A TestSearchDSL.test_query_empty() 0 8 2
A TestSearchDSL.test_query_no_must() 0 9 2
A TestSearchDSL.test_query_range_valid_ranges() 0 14 1
A TestSearchBase.test_search_distance_l2() 0 15 1
A TestSearchDSL.test_query_single_term_multi_fields() 0 14 2
A TestSearchDSL.test_query_multi_term_no_common() 14 14 1
A TestSearchBase.test_search_distance_jaccard_flat_index() 14 14 1
A TestSearchDSL.test_query_term_values_parts_in() 12 12 1
A TestSearchBase.test_search_index_partition_C() 0 18 3
A TestSearchDSL.test_query_multi_range_no_common() 0 14 1
A TestSearchBase.test_search_index_partitions_B() 0 32 4
A TestSearchBase.test_search_after_index_different_metric_type() 0 16 1
A TestSearchDSL.test_query_multi_term_different_fields() 13 13 1
A TestSearchDSLBools.test_query_should_only_term() 0 9 2
A TestSearchInvalid.get_invalid_field() 0 6 1
A TestSearchInvalid.test_search_with_invalid_params_binary() 0 15 2
A TestSearchBase.test_search_ip_after_index() 0 27 4
A TestSearchBase.test_search_distance_jaccard_flat_index_L2() 0 15 2
A TestSearchBase.test_search_index_partition() 0 29 4
A TestSearchDSL.test_query_single_term_range_no_common() 13 13 1
A TestSearchDSL.test_query_term_values_repeat() 13 13 1
A TestSearchDSL.test_query_single_range_multi_fields() 0 14 2
A TestSearchBase.test_search_after_index() 0 26 4
A TestSearchDSL.test_query_term_value_not_in() 0 13 1
A TestSearchDSL.test_query_multi_term_has_common() 14 14 1
A TestSearchDSL.test_query_multi_range_has_common() 0 14 1
A TestSearchInvalid.test_search_with_invalid_params() 18 18 4
A TestSearchBase.test_search_collection_name_not_existed() 0 9 2
A TestSearchInvalid.get_collection_name() 0 6 1

3 Functions

Rating   Name   Duplication   Size   Complexity  
A init_binary_data() 0 19 4
A init_data() 21 21 5
A check_id_result() 0 7 2

How to fix

Duplicated Code

Duplicate code is one of the most pungent code smells. A rule that is often used is to re-structure code once it is duplicated in three or more places.

Common duplication problems, and corresponding solutions are:
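
For instance, the bodies of test_search_flat and test_search_flat_top_k (both flagged as duplicated below) differ only in where top_k comes from. A minimal sketch of one way to remove that duplication, assuming the module's existing helpers and fixtures (the name assert_flat_topk_search is illustrative, not part of the project):

def assert_flat_topk_search(connect, collection, top_k, nq):
    # Shared body of the duplicated FLAT-search tests: insert data, search,
    # then check result length, distance and the returned id.
    entities, ids = init_data(connect, collection)
    query, _ = gen_query_vectors(field_name, entities, top_k, nq)
    if top_k <= max_top_k:
        res = connect.search(collection, query)
        assert len(res[0]) == top_k
        assert res[0]._distances[0] <= epsilon
        assert check_id_result(res[0], ids[0])
    else:
        with pytest.raises(Exception):
            connect.search(collection, query)


class TestSearchBase:
    # The two original tests become one-line wrappers around the helper.
    def test_search_flat(self, connect, collection, get_top_k, get_nq):
        assert_flat_topk_search(connect, collection, get_top_k, get_nq)

    def test_search_flat_top_k(self, connect, collection, get_nq):
        assert_flat_topk_search(connect, collection, 16385, get_nq)

The same move applies to the other regions marked "View Code Duplication" below, such as the substructure, superstructure and tanimoto distance tests.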

Complexity

Tip: Before tackling complexity, make sure that you eliminate any duplication first. This can often reduce the size of classes significantly.

Complex classes like test_search often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
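
Applied to this file, as a hedged sketch only (the mixin name is illustrative; the fixture bodies are copied from TestSearchBase below): the get_*_index fixtures share a prefix and use nothing else from the class, so they are a natural candidate for Extract Class.

class IndexParamFixtures:
    # Candidate cohesive component: index-parameter fixtures that share the
    # get_*_index prefix and depend only on the connect fixture.

    @pytest.fixture(scope="function", params=gen_simple_index())
    def get_simple_index(self, request, connect):
        if str(connect._cmd("mode")) == "CPU":
            if request.param["index_type"] in index_cpu_not_support():
                pytest.skip("sq8h not support in CPU mode")
        return request.param

    @pytest.fixture(scope="function", params=gen_binary_index())
    def get_jaccard_index(self, request, connect):
        if request.param["index_type"] in binary_support():
            return request.param
        pytest.skip("Skip index Temporary")


class TestSearchBase(IndexParamFixtures):
    # The search tests keep using the fixtures exactly as before.
    ...

Moving the same fixtures into a shared conftest.py would achieve the same separation without inheritance.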

1
import time
2
import pdb
3
import copy
4
import threading
5
import logging
6
from multiprocessing import Pool, Process
7
import pytest
8
import numpy as np
9
10
from milvus import DataType
11
from utils import *
12
from constants import *
13
14
uid = "test_search"
15
nq = 1
16
epsilon = 0.001
17
field_name = default_float_vec_field_name
18
binary_field_name = default_binary_vec_field_name
19
search_param = {"nprobe": 1}
20
21
entity = gen_entities(1, is_normal=True)
22
entities = gen_entities(default_nb, is_normal=True)
23
raw_vectors, binary_entities = gen_binary_entities(default_nb)
24
default_query, default_query_vecs = gen_query_vectors(field_name, entities, default_top_k, nq)
25
default_binary_query, default_binary_query_vecs = gen_query_vectors(binary_field_name, binary_entities, default_top_k, nq)
26
27 View Code Duplication
def init_data(connect, collection, nb=1200, partition_tags=None, auto_id=True):
28
    '''
29
    Generate entities and add them to the collection
30
    '''
31
    global entities
32
    if nb == 1200:
33
        insert_entities = entities
34
    else:
35
        insert_entities = gen_entities(nb, is_normal=True)
36
    if partition_tags is None:
37
        if auto_id:
38
            ids = connect.insert(collection, insert_entities)
39
        else:
40
            ids = connect.insert(collection, insert_entities, ids=[i for i in range(nb)])
41
    else:
42
        if auto_id:
43
            ids = connect.insert(collection, insert_entities, partition_tag=partition_tags)
44
        else:
45
            ids = connect.insert(collection, insert_entities, ids=[i for i in range(nb)], partition_tag=partition_tags)
46
    connect.flush([collection])
47
    return insert_entities, ids
48
49
50
def init_binary_data(connect, collection, nb=1200, insert=True, partition_tags=None):
51
    '''
52
    Generate entities and add them to the collection
53
    '''
54
    ids = []
55
    global binary_entities
56
    global raw_vectors
57
    if nb == 1200:
58
        insert_entities = binary_entities
59
        insert_raw_vectors = raw_vectors
60
    else:
61
        insert_raw_vectors, insert_entities = gen_binary_entities(nb)
62
    if insert is True:
63
        if partition_tags is None:
64
            ids = connect.insert(collection, insert_entities)
65
        else:
66
            ids = connect.insert(collection, insert_entities, partition_tag=partition_tags)
67
        connect.flush([collection])
68
    return insert_raw_vectors, insert_entities, ids
69
70
71
class TestSearchBase:
72
    """
73
    generate valid create_index params
74
    """
75
76
    @pytest.fixture(
77
        scope="function",
78
        params=gen_index()
79
    )
80
    def get_index(self, request, connect):
81
        if str(connect._cmd("mode")) == "CPU":
82
            if request.param["index_type"] in index_cpu_not_support():
83
                pytest.skip("sq8h not support in CPU mode")
84
        return request.param
85
86
    @pytest.fixture(
87
        scope="function",
88
        params=gen_simple_index()
89
    )
90
    def get_simple_index(self, request, connect):
91
        if str(connect._cmd("mode")) == "CPU":
92
            if request.param["index_type"] in index_cpu_not_support():
93
                pytest.skip("sq8h not support in CPU mode")
94
        return request.param
95
96
    @pytest.fixture(
97
        scope="function",
98
        params=gen_binary_index()
99
    )
100
    def get_jaccard_index(self, request, connect):
101
        logging.getLogger().info(request.param)
102
        if request.param["index_type"] in binary_support():
103
            return request.param
104
        else:
105
            pytest.skip("Skip index Temporary")
106
107
    @pytest.fixture(
108
        scope="function",
109
        params=gen_binary_index()
110
    )
111
    def get_hamming_index(self, request, connect):
112
        logging.getLogger().info(request.param)
113
        if request.param["index_type"] in binary_support():
114
            return request.param
115
        else:
116
            pytest.skip("Skip index Temporary")
117
118
    @pytest.fixture(
119
        scope="function",
120
        params=gen_binary_index()
121
    )
122
    def get_structure_index(self, request, connect):
123
        logging.getLogger().info(request.param)
124
        if request.param["index_type"] == "FLAT":
125
            return request.param
126
        else:
127
            pytest.skip("Skip index Temporary")
128
129
    """
130
    generate top-k params
131
    """
132
133
    @pytest.fixture(
134
        scope="function",
135
        params=[1, 10]
136
    )
137
    def get_top_k(self, request):
138
        yield request.param
139
140
    @pytest.fixture(
141
        scope="function",
142
        params=[1, 10, 1100]
143
    )
144
    def get_nq(self, request):
145
        yield request.param
146
147 View Code Duplication
    def test_search_flat(self, connect, collection, get_top_k, get_nq):
148
        '''
149
        target: test basic search function, all the search params are correct, change top-k value
150
        method: search with the given vectors, check the result
151
        expected: the length of the result is top_k
152
        '''
153
        top_k = get_top_k
154
        nq = get_nq
155
        entities, ids = init_data(connect, collection)
156
        query, vecs = gen_query_vectors(field_name, entities, top_k, nq)
157
        if top_k <= max_top_k:
158
            res = connect.search(collection, query)
159
            assert len(res[0]) == top_k
160
            assert res[0]._distances[0] <= epsilon
161
            assert check_id_result(res[0], ids[0])
162
        else:
163
            with pytest.raises(Exception) as e:
164
                res = connect.search(collection, query)
165
166 View Code Duplication
    def test_search_flat_top_k(self, connect, collection, get_nq):
167
        '''
168
        target: test basic search function, all the search params are correct, change top-k value
169
        method: search with the given vectors, check the result
170
        expected: the length of the result is top_k
171
        '''
172
        top_k = 16385
173
        nq = get_nq
174
        entities, ids = init_data(connect, collection)
175
        query, vecs = gen_query_vectors(field_name, entities, top_k, nq)
176
        if top_k <= max_top_k:
177
            res = connect.search(collection, query)
178
            assert len(res[0]) == top_k
179
            assert res[0]._distances[0] <= epsilon
180
            assert check_id_result(res[0], ids[0])
181
        else:
182
            with pytest.raises(Exception) as e:
183
                res = connect.search(collection, query)
184
185
    def test_search_field(self, connect, collection, get_top_k, get_nq):
186
        '''
187
        target: test basic search function, all the search params are correct, change top-k value
188
        method: search with the given vectors, check the result
189
        expected: the length of the result is top_k
190
        '''
191
        top_k = get_top_k
192
        nq = get_nq
193
        entities, ids = init_data(connect, collection)
194
        query, vecs = gen_query_vectors(field_name, entities, top_k, nq)
195
        if top_k <= max_top_k:
196
            res = connect.search(collection, query, fields=["float_vector"])
197
            assert len(res[0]) == top_k
198
            assert res[0]._distances[0] <= epsilon
199
            assert check_id_result(res[0], ids[0])
200
            res = connect.search(collection, query, fields=["float"])
201
            for i in range(nq):
202
                assert entities[1]["values"][:nq][i] in [r.entity.get('float') for r in res[i]]
203
        else:
204
            with pytest.raises(Exception):
205
                connect.search(collection, query)
206
207
    def test_search_after_delete(self, connect, collection, get_top_k, get_nq):
208
        '''
209
        target: test basic search function before and after deletion, all the search params are
210
                correct, change top-k value.
211
                check issue #4200: https://github.com/milvus-io/milvus/issues/4200
212
        method: search with the given vectors, check the result
213
        expected: the deleted entities do not exist in the result.
214
        '''
215
        top_k = get_top_k
216
        nq = get_nq
217
218
        entities, ids = init_data(connect, collection, nb=10000)
219
        first_int64_value = entities[0]["values"][0]
220
        first_vector = entities[2]["values"][0]
221
222
        search_param = get_search_param("FLAT")
223
        query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
224
        vecs[:] = []
225
        vecs.append(first_vector)
226
227
        res = None
228
        if top_k > max_top_k:
229
            with pytest.raises(Exception):
230
                connect.search(collection, query, fields=['int64'])
231
            pytest.skip("top_k value is larger than max_top_k")
232
        else:
233
            res = connect.search(collection, query, fields=['int64'])
234
            assert len(res) == 1
235
            assert len(res[0]) >= top_k
236
            assert res[0][0].id == ids[0]
237
            assert res[0][0].entity.get("int64") == first_int64_value
238
            assert res[0]._distances[0] < epsilon
239
            assert check_id_result(res[0], ids[0])
240
241
        connect.delete_entity_by_id(collection, ids[:1])
242
        connect.flush([collection])
243
244
        res2 = connect.search(collection, query, fields=['int64'])
245
        assert len(res2) == 1
246
        assert len(res2[0]) >= top_k
247
        assert res2[0][0].id != ids[0]
248
        if top_k > 1:
249
            assert res2[0][0].id == res[0][1].id
250
            assert res2[0][0].entity.get("int64") == res[0][1].entity.get("int64")
251
252
    # TODO:
253
    @pytest.mark.level(2)
254
    def test_search_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq):
255
        '''
256
        target: test basic search function, all the search params are correct, test all index params, and build
257
        method: search with the given vectors, check the result
258
        expected: the length of the result is top_k
259
        '''
260
        top_k = get_top_k
261
        nq = get_nq
262
263
        index_type = get_simple_index["index_type"]
264
        if index_type in skip_pq():
Comprehensibility Best Practice: The variable skip_pq does not seem to be defined.
265
            pytest.skip("Skip PQ")
266
        entities, ids = init_data(connect, collection)
267
        connect.create_index(collection, field_name, get_simple_index)
268
        search_param = get_search_param(index_type)
269
        query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
270
        if top_k > max_top_k:
271
            with pytest.raises(Exception) as e:
272
                res = connect.search(collection, query)
273
        else:
274
            res = connect.search(collection, query)
275
            assert len(res) == nq
276
            assert len(res[0]) >= top_k
277
            assert res[0]._distances[0] < epsilon
278
            assert check_id_result(res[0], ids[0])
279
280
    def test_search_after_index_different_metric_type(self, connect, collection, get_simple_index):
281
        '''
282
        target: test search with different metric_type
283
        method: build index with L2, and search using IP
284
        expected: search ok
285
        '''
286
        search_metric_type = "IP"
287
        index_type = get_simple_index["index_type"]
288
        entities, ids = init_data(connect, collection)
289
        connect.create_index(collection, field_name, get_simple_index)
290
        search_param = get_search_param(index_type)
291
        query, vecs = gen_query_vectors(field_name, entities, default_top_k, nq, metric_type=search_metric_type,
292
                                        search_params=search_param)
293
        res = connect.search(collection, query)
294
        assert len(res) == nq
295
        assert len(res[0]) == default_top_k
296
297
    @pytest.mark.level(2)
298
    def test_search_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq):
299
        '''
300
        target: test basic search function, all the search params are correct, test all index params, and build
301
        method: add vectors into collection, search with the given vectors, check the result
302
        expected: the length of the result is top_k, search collection with partition tag return empty
303
        '''
304
        top_k = get_top_k
305
        nq = get_nq
306
307
        index_type = get_simple_index["index_type"]
308
        if index_type in skip_pq():
Comprehensibility Best Practice: The variable skip_pq does not seem to be defined.
309
            pytest.skip("Skip PQ")
310
        connect.create_partition(collection, default_tag)
311
        entities, ids = init_data(connect, collection)
312
        connect.create_index(collection, field_name, get_simple_index)
313
        search_param = get_search_param(index_type)
314
        query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
315
        if top_k > max_top_k:
316
            with pytest.raises(Exception) as e:
317
                res = connect.search(collection, query)
318
        else:
319
            res = connect.search(collection, query)
320
            assert len(res) == nq
321
            assert len(res[0]) >= top_k
322
            assert res[0]._distances[0] < epsilon
323
            assert check_id_result(res[0], ids[0])
324
            res = connect.search(collection, query, partition_tags=[default_tag])
325
            assert len(res) == nq
326
327
    @pytest.mark.level(2)
328
    def test_search_index_partition_B(self, connect, collection, get_simple_index, get_top_k, get_nq):
329
        '''
330
        target: test basic search function, all the search params are correct, test all index params, and build
331
        method: search with the given vectors, check the result
332
        expected: the length of the result is top_k
333
        '''
334
        top_k = get_top_k
335
        nq = get_nq
336
337
        index_type = get_simple_index["index_type"]
338
        if index_type in skip_pq():
Comprehensibility Best Practice: The variable skip_pq does not seem to be defined.
339
            pytest.skip("Skip PQ")
340
        connect.create_partition(collection, default_tag)
341
        entities, ids = init_data(connect, collection, partition_tags=default_tag)
342
        connect.create_index(collection, field_name, get_simple_index)
343
        search_param = get_search_param(index_type)
344
        query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
345
        for tags in [[default_tag], [default_tag, "new_tag"]]:
346
            if top_k > max_top_k:
347
                with pytest.raises(Exception) as e:
348
                    res = connect.search(collection, query, partition_tags=tags)
349
            else:
350
                res = connect.search(collection, query, partition_tags=tags)
351
                assert len(res) == nq
352
                assert len(res[0]) >= top_k
353
                assert res[0]._distances[0] < epsilon
354
                assert check_id_result(res[0], ids[0])
355
356
    @pytest.mark.level(2)
357
    def test_search_index_partition_C(self, connect, collection, get_top_k, get_nq):
358
        '''
359
        target: test basic search function, all the search params are correct, test all index params, and build
360
        method: search with the given vectors and tag (tag name not existed in collection), check the result
361
        expected: error raised
362
        '''
363
        top_k = get_top_k
364
        nq = get_nq
365
        entities, ids = init_data(connect, collection)
366
        query, vecs = gen_query_vectors(field_name, entities, top_k, nq)
367
        if top_k > max_top_k:
368
            with pytest.raises(Exception) as e:
369
                res = connect.search(collection, query, partition_tags=["new_tag"])
370
        else:
371
            res = connect.search(collection, query, partition_tags=["new_tag"])
372
            assert len(res) == nq
373
            assert len(res[0]) == 0
374
375
    @pytest.mark.level(2)
376
    def test_search_index_partitions(self, connect, collection, get_simple_index, get_top_k):
377
        '''
378
        target: test basic search function, all the search params are correct, test all index params, and build
379
        method: search collection with the given vectors and tags, check the result
380
        expected: the length of the result is top_k
381
        '''
382
        top_k = get_top_k
383
        nq = 2
384
        new_tag = "new_tag"
385
        index_type = get_simple_index["index_type"]
386
        if index_type in skip_pq():
Comprehensibility Best Practice: The variable skip_pq does not seem to be defined.
387
            pytest.skip("Skip PQ")
388
        connect.create_partition(collection, default_tag)
389
        connect.create_partition(collection, new_tag)
390
        entities, ids = init_data(connect, collection, partition_tags=default_tag)
391
        new_entities, new_ids = init_data(connect, collection, nb=6001, partition_tags=new_tag)
392
        connect.create_index(collection, field_name, get_simple_index)
393
        search_param = get_search_param(index_type)
394
        query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
395
        if top_k > max_top_k:
396
            with pytest.raises(Exception) as e:
397
                res = connect.search(collection, query)
398
        else:
399
            res = connect.search(collection, query)
400
            assert check_id_result(res[0], ids[0])
401
            assert not check_id_result(res[1], new_ids[0])
402
            assert res[0]._distances[0] < epsilon
403
            assert res[1]._distances[0] < epsilon
404
            res = connect.search(collection, query, partition_tags=["new_tag"])
405
            assert res[0]._distances[0] > epsilon
406
            assert res[1]._distances[0] > epsilon
407
408
    @pytest.mark.level(2)
409
    def test_search_index_partitions_B(self, connect, collection, get_simple_index, get_top_k):
410
        '''
411
        target: test basic search function, all the search params is corrent, test all index params, and build
412
        method: search collection with the given vectors and tags, check the result
413
        expected: the length of the result is top_k
414
        '''
415
        top_k = get_top_k
416
        nq = 2
417
        tag = "tag"
418
        new_tag = "new_tag"
419
        index_type = get_simple_index["index_type"]
420
        if index_type in skip_pq():
Comprehensibility Best Practice: The variable skip_pq does not seem to be defined.
421
            pytest.skip("Skip PQ")
422
        connect.create_partition(collection, tag)
423
        connect.create_partition(collection, new_tag)
424
        entities, ids = init_data(connect, collection, partition_tags=tag)
425
        new_entities, new_ids = init_data(connect, collection, nb=6001, partition_tags=new_tag)
426
        connect.create_index(collection, field_name, get_simple_index)
427
        search_param = get_search_param(index_type)
428
        query, vecs = gen_query_vectors(field_name, new_entities, top_k, nq, search_params=search_param)
429
        if top_k > max_top_k:
430
            with pytest.raises(Exception) as e:
431
                res = connect.search(collection, query)
432
        else:
433
            res = connect.search(collection, query, partition_tags=["(.*)tag"])
434
            assert not check_id_result(res[0], ids[0])
435
            assert res[0]._distances[0] < epsilon
436
            assert res[1]._distances[0] < epsilon
437
            res = connect.search(collection, query, partition_tags=["new(.*)"])
438
            assert res[0]._distances[0] < epsilon
439
            assert res[1]._distances[0] < epsilon
440
441
    #
442
    # test for ip metric
443
    #
444
    @pytest.mark.level(2)
445
    def test_search_ip_flat(self, connect, collection, get_simple_index, get_top_k, get_nq):
446
        '''
447
        target: test basic search function, all the search params are correct, change top-k value
448
        method: search with the given vectors, check the result
449
        expected: the length of the result is top_k
450
        '''
451
        top_k = get_top_k
452
        nq = get_nq
453
        entities, ids = init_data(connect, collection)
454
        query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type="IP")
455
        if top_k <= max_top_k:
456
            res = connect.search(collection, query)
457
            assert len(res[0]) == top_k
458
            assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])
459
            assert check_id_result(res[0], ids[0])
460
        else:
461
            with pytest.raises(Exception) as e:
462
                res = connect.search(collection, query)
463
464
    @pytest.mark.level(2)
465
    def test_search_ip_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq):
466
        '''
467
        target: test basic search function, all the search params are correct, test all index params, and build
468
        method: search with the given vectors, check the result
469
        expected: the length of the result is top_k
470
        '''
471
        top_k = get_top_k
472
        nq = get_nq
473
474
        index_type = get_simple_index["index_type"]
475
        if index_type in skip_pq():
Comprehensibility Best Practice: The variable skip_pq does not seem to be defined.
476
            pytest.skip("Skip PQ")
477
        entities, ids = init_data(connect, collection)
478
        get_simple_index["metric_type"] = "IP"
479
        connect.create_index(collection, field_name, get_simple_index)
480
        search_param = get_search_param(index_type)
481
        query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type="IP", search_params=search_param)
482
        if top_k > max_top_k:
483
            with pytest.raises(Exception) as e:
484
                res = connect.search(collection, query)
485
        else:
486
            res = connect.search(collection, query)
487
            assert len(res) == nq
488
            assert len(res[0]) >= top_k
489
            assert check_id_result(res[0], ids[0])
490
            assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])
491
492
    @pytest.mark.level(2)
493
    def test_search_ip_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq):
494
        '''
495
        target: test basic search function, all the search params are correct, test all index params, and build
496
        method: add vectors into collection, search with the given vectors, check the result
497
        expected: the length of the result is top_k, search collection with partition tag return empty
498
        '''
499
        top_k = get_top_k
500
        nq = get_nq
501
        metric_type = "IP"
502
        index_type = get_simple_index["index_type"]
503
        if index_type in skip_pq():
Comprehensibility Best Practice: The variable skip_pq does not seem to be defined.
504
            pytest.skip("Skip PQ")
505
        connect.create_partition(collection, default_tag)
506
        entities, ids = init_data(connect, collection)
507
        get_simple_index["metric_type"] = metric_type
508
        connect.create_index(collection, field_name, get_simple_index)
509
        search_param = get_search_param(index_type)
510
        query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type=metric_type,
511
                                        search_params=search_param)
512
        if top_k > max_top_k:
513
            with pytest.raises(Exception) as e:
514
                res = connect.search(collection, query)
515
        else:
516
            res = connect.search(collection, query)
517
            assert len(res) == nq
518
            assert len(res[0]) >= top_k
519
            assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])
520
            assert check_id_result(res[0], ids[0])
521
            res = connect.search(collection, query, partition_tags=[default_tag])
522
            assert len(res) == nq
523
524
    @pytest.mark.level(2)
525
    def test_search_ip_index_partitions(self, connect, collection, get_simple_index, get_top_k):
526
        '''
527
        target: test basic search function, all the search params are correct, test all index params, and build
528
        method: search collection with the given vectors and tags, check the result
529
        expected: the length of the result is top_k
530
        '''
531
        top_k = get_top_k
532
        nq = 2
533
        metric_type = "IP"
534
        new_tag = "new_tag"
535
        index_type = get_simple_index["index_type"]
536
        if index_type in skip_pq():
Comprehensibility Best Practice: The variable skip_pq does not seem to be defined.
537
            pytest.skip("Skip PQ")
538
        connect.create_partition(collection, default_tag)
539
        connect.create_partition(collection, new_tag)
540
        entities, ids = init_data(connect, collection, partition_tags=default_tag)
541
        new_entities, new_ids = init_data(connect, collection, nb=6001, partition_tags=new_tag)
542
        get_simple_index["metric_type"] = metric_type
543
        connect.create_index(collection, field_name, get_simple_index)
544
        search_param = get_search_param(index_type)
545
        query, vecs = gen_query_vectors(field_name, entities, top_k, nq, metric_type="IP", search_params=search_param)
546
        if top_k > max_top_k:
547
            with pytest.raises(Exception) as e:
548
                res = connect.search(collection, query)
549
        else:
550
            res = connect.search(collection, query)
551
            assert check_id_result(res[0], ids[0])
552
            assert not check_id_result(res[1], new_ids[0])
553
            assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])
554
            assert res[1]._distances[0] >= 1 - gen_inaccuracy(res[1]._distances[0])
555
            res = connect.search(collection, query, partition_tags=["new_tag"])
556
            assert res[0]._distances[0] < 1 - gen_inaccuracy(res[0]._distances[0])
557
            # TODO:
558
            # assert res[1]._distances[0] >= 1 - gen_inaccuracy(res[1]._distances[0])
559
560
    @pytest.mark.level(2)
561
    def test_search_without_connect(self, dis_connect, collection):
562
        '''
563
        target: test search vectors without connection
564
        method: use a disconnected instance, call the search method and check that it fails
565
        expected: raise exception
566
        '''
567
        with pytest.raises(Exception) as e:
568
            res = dis_connect.search(collection, default_query)
569
570
    def test_search_collection_name_not_existed(self, connect):
571
        '''
572
        target: search collection not existed
573
        method: search with the random collection_name, which is not in db
574
        expected: status not ok
575
        '''
576
        collection_name = gen_unique_str(uid)
577
        with pytest.raises(Exception) as e:
578
            res = connect.search(collection_name, default_query)
579
580
    def test_search_distance_l2(self, connect, collection):
581
        '''
582
        target: search collection, and check the result: distance
583
        method: compare the return distance value with value computed with Euclidean
584
        expected: the return distance equals to the computed value
585
        '''
586
        nq = 2
587
        search_param = {"nprobe": 1}
588
        entities, ids = init_data(connect, collection, nb=nq)
589
        query, vecs = gen_query_vectors(field_name, entities, default_top_k, nq, rand_vector=True, search_params=search_param)
590
        inside_query, inside_vecs = gen_query_vectors(field_name, entities, default_top_k, nq, search_params=search_param)
591
        distance_0 = l2(vecs[0], inside_vecs[0])
592
        distance_1 = l2(vecs[0], inside_vecs[1])
593
        res = connect.search(collection, query)
594
        assert abs(np.sqrt(res[0]._distances[0]) - min(distance_0, distance_1)) <= gen_inaccuracy(res[0]._distances[0])
595
596 View Code Duplication
    def test_search_distance_l2_after_index(self, connect, id_collection, get_simple_index):
597
        '''
598
        target: search collection, and check the result: distance
599
        method: compare the return distance value with value computed with Euclidean (L2)
600
        expected: the return distance equals to the computed value
601
        '''
602
        index_type = get_simple_index["index_type"]
603
        nq = 2
604
        entities, ids = init_data(connect, id_collection, auto_id=False)
605
        connect.create_index(id_collection, field_name, get_simple_index)
606
        search_param = get_search_param(index_type)
607
        query, vecs = gen_query_vectors(field_name, entities, default_top_k, nq, rand_vector=True, search_params=search_param)
608
        inside_vecs = entities[-1]["values"]
609
        min_distance = 1.0
610
        min_id = None
611
        for i in range(default_nb):
612
            tmp_dis = l2(vecs[0], inside_vecs[i])
613
            if min_distance > tmp_dis:
614
                min_distance = tmp_dis
615
                min_id = ids[i]
616
        res = connect.search(id_collection, query)
617
        tmp_epsilon = epsilon
618
        check_id_result(res[0], min_id)
619
        # if index_type in ["ANNOY", "IVF_PQ"]:
620
        #     tmp_epsilon = 0.1
621
        # TODO:
622
        # assert abs(np.sqrt(res[0]._distances[0]) - min_distance) <= tmp_epsilon
623
624
    @pytest.mark.level(2)
625
    def test_search_distance_ip(self, connect, collection):
626
        '''
627
        target: search collection, and check the result: distance
628
        method: compare the return distance value with value computed with Inner product
629
        expected: the return distance equals to the computed value
630
        '''
631
        nq = 2
632
        metric_type = "IP"
633
        search_param = {"nprobe": 1}
634
        entities, ids = init_data(connect, collection, nb=nq)
635
        query, vecs = gen_query_vectors(field_name, entities, default_top_k, nq, rand_vector=True, metric_type=metric_type,
636
                                        search_params=search_param)
637
        inside_query, inside_vecs = gen_query_vectors(field_name, entities, default_top_k, nq, search_params=search_param)
638
        distance_0 = ip(vecs[0], inside_vecs[0])
639
        distance_1 = ip(vecs[0], inside_vecs[1])
640
        res = connect.search(collection, query)
641
        assert abs(res[0]._distances[0] - max(distance_0, distance_1)) <= epsilon
642
643 View Code Duplication
    def test_search_distance_ip_after_index(self, connect, id_collection, get_simple_index):
644
        '''
645
        target: search collection, and check the result: distance
646
        method: compare the return distance value with value computed with Inner product
647
        expected: the return distance equals to the computed value
648
        '''
649
        index_type = get_simple_index["index_type"]
650
        nq = 2
651
        metric_type = "IP"
652
        entities, ids = init_data(connect, id_collection, auto_id=False)
653
        get_simple_index["metric_type"] = metric_type
654
        connect.create_index(id_collection, field_name, get_simple_index)
655
        search_param = get_search_param(index_type)
656
        query, vecs = gen_query_vectors(field_name, entities, default_top_k, nq, rand_vector=True, metric_type=metric_type,
657
                                        search_params=search_param)
658
        inside_vecs = entities[-1]["values"]
659
        max_distance = 0
660
        max_id = None
661
        for i in range(default_nb):
662
            tmp_dis = ip(vecs[0], inside_vecs[i])
663
            if max_distance < tmp_dis:
664
                max_distance = tmp_dis
665
                max_id = ids[i]
666
        res = connect.search(id_collection, query)
667
        tmp_epsilon = epsilon
668
        check_id_result(res[0], max_id)
669
        # if index_type in ["ANNOY", "IVF_PQ"]:
670
        #     tmp_epsilon = 0.1
671
        # TODO:
672
        # assert abs(res[0]._distances[0] - max_distance) <= tmp_epsilon
673
674 View Code Duplication
    def test_search_distance_jaccard_flat_index(self, connect, binary_collection):
675
        '''
676
        target: search binary_collection, and check the result: distance
677
        method: compare the return distance value with value computed with Jaccard
678
        expected: the return distance equals to the computed value
679
        '''
680
        nq = 1
681
        int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2)
682
        query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False)
683
        distance_0 = jaccard(query_int_vectors[0], int_vectors[0])
684
        distance_1 = jaccard(query_int_vectors[0], int_vectors[1])
685
        query, vecs = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq, metric_type="JACCARD")
686
        res = connect.search(binary_collection, query)
687
        assert abs(res[0]._distances[0] - min(distance_0, distance_1)) <= epsilon
688
689
    @pytest.mark.level(2)
690
    def test_search_distance_jaccard_flat_index_L2(self, connect, binary_collection):
691
        '''
692
        target: search binary_collection, and check the result: distance
693
        method: search binary_collection using the L2 metric type
694
        expected: exception raised
695
        '''
696
        nq = 1
697
        int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2)
698
        query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False)
699
        distance_0 = jaccard(query_int_vectors[0], int_vectors[0])
700
        distance_1 = jaccard(query_int_vectors[0], int_vectors[1])
701
        query, vecs = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq, metric_type="L2")
702
        with pytest.raises(Exception) as e:
703
            res = connect.search(binary_collection, query)
704
705 View Code Duplication
    @pytest.mark.level(2)
706
    def test_search_distance_hamming_flat_index(self, connect, binary_collection):
707
        '''
708
        target: search binary_collection, and check the result: distance
709
        method: compare the return distance value with value computed with Hamming
710
        expected: the return distance equals to the computed value
711
        '''
712
        nq = 1
713
        int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2)
714
        query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False)
715
        distance_0 = hamming(query_int_vectors[0], int_vectors[0])
716
        distance_1 = hamming(query_int_vectors[0], int_vectors[1])
717
        query, vecs = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq, metric_type="HAMMING")
718
        res = connect.search(binary_collection, query)
719
        assert abs(res[0][0].distance - min(distance_0, distance_1).astype(float)) <= epsilon
720
721 View Code Duplication
    @pytest.mark.level(2)
722
    def test_search_distance_substructure_flat_index(self, connect, binary_collection):
723
        '''
724
        target: search binary_collection, and check the result: distance
725
        method: compare the return distance value with value computed with SUBSTRUCTURE
726
        expected: the return distance equals to the computed value
727
        '''
728
        nq = 1
729
        int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2)
730
        query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False)
731
        distance_0 = substructure(query_int_vectors[0], int_vectors[0])
732
        distance_1 = substructure(query_int_vectors[0], int_vectors[1])
733
        query, vecs = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq, metric_type="SUBSTRUCTURE")
734
        res = connect.search(binary_collection, query)
735
        assert len(res[0]) == 0
736
737
    @pytest.mark.level(2)
738
    def test_search_distance_substructure_flat_index_B(self, connect, binary_collection):
739
        '''
740
        target: search binary_collection, and check the result: distance
741
        method: compare the return distance value with value computed with SUB
742
        expected: the return distance equals to the computed value
743
        '''
744
        top_k = 3
745
        int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2)
746
        query_int_vectors, query_vecs = gen_binary_sub_vectors(int_vectors, 2)
747
        query, vecs = gen_query_vectors(binary_field_name, entities, top_k, nq, metric_type="SUBSTRUCTURE", replace_vecs=query_vecs)
748
        res = connect.search(binary_collection, query)
749
        assert res[0][0].distance <= epsilon
750
        assert res[0][0].id == ids[0]
751
        assert res[1][0].distance <= epsilon
752
        assert res[1][0].id == ids[1]
753
754 View Code Duplication
    @pytest.mark.level(2)
755
    def test_search_distance_superstructure_flat_index(self, connect, binary_collection):
756
        '''
757
        target: search binary_collection, and check the result: distance
758
        method: compare the return distance value with value computed with SUPERSTRUCTURE
759
        expected: the return distance equals to the computed value
760
        '''
761
        nq = 1
762
        int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2)
763
        query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False)
764
        distance_0 = superstructure(query_int_vectors[0], int_vectors[0])
765
        distance_1 = superstructure(query_int_vectors[0], int_vectors[1])
766
        query, vecs = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq, metric_type="SUPERSTRUCTURE")
767
        res = connect.search(binary_collection, query)
768
        assert len(res[0]) == 0
769
770
    @pytest.mark.level(2)
771
    def test_search_distance_superstructure_flat_index_B(self, connect, binary_collection):
772
        '''
773
        target: search binary_collection, and check the result: distance
774
        method: compare the return distance value with value computed with SUPER
775
        expected: the return distance equals to the computed value
776
        '''
777
        top_k = 3
778
        int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2)
779
        query_int_vectors, query_vecs = gen_binary_super_vectors(int_vectors, 2)
780
        query, vecs = gen_query_vectors(binary_field_name, entities, top_k, nq, metric_type="SUPERSTRUCTURE", replace_vecs=query_vecs)
781
        res = connect.search(binary_collection, query)
782
        assert len(res[0]) == 2
783
        assert len(res[1]) == 2
784
        assert res[0][0].id in ids
785
        assert res[0][0].distance <= epsilon
786
        assert res[1][0].id in ids
787
        assert res[1][0].distance <= epsilon
788
789 View Code Duplication
    @pytest.mark.level(2)
790
    def test_search_distance_tanimoto_flat_index(self, connect, binary_collection):
791
        '''
792
        target: search binary_collection, and check the result: distance
793
        method: compare the return distance value with value computed with Tanimoto
794
        expected: the return distance equals to the computed value
795
        '''
796
        nq = 1
797
        int_vectors, entities, ids = init_binary_data(connect, binary_collection, nb=2)
798
        query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False)
799
        distance_0 = tanimoto(query_int_vectors[0], int_vectors[0])
800
        distance_1 = tanimoto(query_int_vectors[0], int_vectors[1])
801
        query, vecs = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq, metric_type="TANIMOTO")
802
        res = connect.search(binary_collection, query)
803
        assert abs(res[0][0].distance - min(distance_0, distance_1)) <= epsilon
804
805
    @pytest.mark.level(2)
806
    @pytest.mark.timeout(30)
807
    def test_search_concurrent_multithreads(self, connect, args):
808
        '''
809
        target: test concurrent search with multiple threads
810
        method: search from 4 threads, each thread using its own connection
811
        expected: status ok and the returned vectors should be query_records
812
        '''
813
        nb = 100
814
        top_k = 10
815
        threads_num = 4
816
        threads = []
817
        collection = gen_unique_str(uid)
818
        uri = "tcp://%s:%s" % (args["ip"], args["port"])
819
        # create collection
820
        milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
Comprehensibility Best Practice: The variable get_milvus does not seem to be defined.
821
        milvus.create_collection(collection, default_fields)
822
        entities, ids = init_data(milvus, collection)
823
824
        def search(milvus):
825
            res = connect.search(collection, default_query)
826
            assert len(res) == 1
827
            assert res[0]._entities[0].id in ids
828
            assert res[0]._distances[0] < epsilon
829
830
        for i in range(threads_num):
831
            milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
832
            t = threading.Thread(target=search, args=(milvus,))
833
            threads.append(t)
834
            t.start()
835
            time.sleep(0.2)
836
        for t in threads:
837
            t.join()
838
839
    @pytest.mark.level(2)
840
    @pytest.mark.timeout(30)
841
    def test_search_concurrent_multithreads_single_connection(self, connect, args):
842
        '''
843
        target: test concurrent search with multiple threads sharing a single connection
844
        method: search from 4 threads that share one connection
845
        expected: status ok and the returned vectors should be query_records
846
        '''
847
        nb = 100
848
        top_k = 10
849
        threads_num = 4
850
        threads = []
851
        collection = gen_unique_str(uid)
852
        uri = "tcp://%s:%s" % (args["ip"], args["port"])
853
        # create collection
854
        milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
Comprehensibility Best Practice: The variable get_milvus does not seem to be defined.
855
        milvus.create_collection(collection, default_fields)
856
        entities, ids = init_data(milvus, collection)
857
858
        def search(milvus):
859
            res = connect.search(collection, default_query)
860
            assert len(res) == 1
861
            assert res[0]._entities[0].id in ids
862
            assert res[0]._distances[0] < epsilon
863
864
        for i in range(threads_num):
865
            t = threading.Thread(target=search, args=(milvus,))
866
            threads.append(t)
867
            t.start()
868
            time.sleep(0.2)
869
        for t in threads:
870
            t.join()
871
872
    @pytest.mark.level(2)
873
    def test_search_multi_collections(self, connect, args):
874
        '''
875
        target: test search multi collections of L2
876
        method: add vectors into 10 collections, and search
877
        expected: search status ok, the length of the result equals nq
878
        '''
879
        num = 10
880
        top_k = 10
881
        nq = 20
882
        for i in range(num):
883
            collection = gen_unique_str(uid + str(i))
884
            connect.create_collection(collection, default_fields)
885
            entities, ids = init_data(connect, collection)
886
            assert len(ids) == default_nb
887
            query, vecs = gen_query_vectors(field_name, entities, top_k, nq, search_params=search_param)
888
            res = connect.search(collection, query)
889
            assert len(res) == nq
890
            for i in range(nq):
891
                assert check_id_result(res[i], ids[i])
892
                assert res[i]._distances[0] < epsilon
893
                assert res[i]._distances[1] > epsilon
894
895
896
class TestSearchDSL(object):
897
    """
898
    ******************************************************************
899
    #  The following cases are used to build invalid query expr
900
    ******************************************************************
901
    """
902
903
    def test_query_no_must(self, connect, collection):
904
        '''
905
        method: build query without must expr
906
        expected: error raised
907
        '''
908
        # entities, ids = init_data(connect, collection)
909
        query = update_query_expr(default_query, keep_old=False)
910
        with pytest.raises(Exception) as e:
911
            res = connect.search(collection, query)
912
913
    def test_query_no_vector_term_only(self, connect, collection):
914
        '''
915
        method: build query with only a term expr and no vector
916
        expected: error raised
917
        '''
918
        # entities, ids = init_data(connect, collection)
919
        expr = {
920
            "must": [gen_default_term_expr]
921
        }
922
        query = update_query_expr(default_query, keep_old=False, expr=expr)
923
        with pytest.raises(Exception) as e:
924
            res = connect.search(collection, query)
925
926
    def test_query_no_vector_range_only(self, connect, collection):
927
        '''
928
        method: build query with only a range expr and no vector
929
        expected: error raised
930
        '''
931
        # entities, ids = init_data(connect, collection)
932
        expr = {
933
            "must": [gen_default_range_expr]
934
        }
935
        query = update_query_expr(default_query, keep_old=False, expr=expr)
936
        with pytest.raises(Exception) as e:
937
            res = connect.search(collection, query)
938
939
    def test_query_vector_only(self, connect, collection):
940
        entities, ids = init_data(connect, collection)
941
        res = connect.search(collection, default_query)
942
        assert len(res) == nq
943
        assert len(res[0]) == default_top_k
944
945
    def test_query_wrong_format(self, connect, collection):
946
        '''
947
        method: build query without must expr, with wrong expr name
948
        expected: error raised
949
        '''
950
        # entities, ids = init_data(connect, collection)
951
        expr = {
952
            "must1": [gen_default_term_expr]
953
        }
954
        query = update_query_expr(default_query, keep_old=False, expr=expr)
955
        with pytest.raises(Exception) as e:
956
            res = connect.search(collection, query)
957
958
    def test_query_empty(self, connect, collection):
959
        '''
960
        method: search with empty query
961
        expected: error raised
962
        '''
963
        query = {}
964
        with pytest.raises(Exception) as e:
965
            res = connect.search(collection, query)
966
967
    """
968
    ******************************************************************
969
    #  The following cases are used to build valid query expr
970
    ******************************************************************
971
    """
972
973
    @pytest.mark.level(2)
974
    def test_query_term_value_not_in(self, connect, collection):
975
        '''
976
        method: build query with vector and term expr, where no term can be filtered
977
        expected: filter pass
978
        '''
979
        entities, ids = init_data(connect, collection)
980
        expr = {
981
            "must": [gen_default_vector_expr(default_query), gen_default_term_expr(values=[100000])]}
982
        query = update_query_expr(default_query, expr=expr)
983
        res = connect.search(collection, query)
984
        assert len(res) == nq
985
        assert len(res[0]) == 0
986
        # TODO:
987
988
    # TODO:
989
    @pytest.mark.level(2)
990
    def test_query_term_value_all_in(self, connect, collection):
991
        '''
992
        method: build query with vector and term expr, with all term can be filtered
993
        expected: filter pass
994
        '''
995
        entities, ids = init_data(connect, collection)
996
        expr = {"must": [gen_default_vector_expr(default_query), gen_default_term_expr(values=[1])]}
997
        query = update_query_expr(default_query, expr=expr)
998
        res = connect.search(collection, query)
999
        assert len(res) == nq
1000
        assert len(res[0]) == 1
1001
        # TODO:
1002
1003
    # TODO:
1004 View Code Duplication
    @pytest.mark.level(2)
0 ignored issues
show
Duplication introduced by
This code seems to be duplicated in your project.
Loading history...
1005
    def test_query_term_values_not_in(self, connect, collection):
1006
        '''
1007
        method: build query with vector and term expr, with no term can be filtered
1008
        expected: filter pass
1009
        '''
1010
        entities, ids = init_data(connect, collection)
1011
        expr = {"must": [gen_default_vector_expr(default_query),
1012
                         gen_default_term_expr(values=[i for i in range(100000, 100010)])]}
1013
        query = update_query_expr(default_query, expr=expr)
1014
        res = connect.search(collection, query)
1015
        assert len(res) == nq
1016
        assert len(res[0]) == 0
1017
        # TODO:
1018
1019
    def test_query_term_values_all_in(self, connect, collection):
1020
        '''
1021
        method: build query with vector and term expr, with all term can be filtered
1022
        expected: filter pass
1023
        '''
1024
        entities, ids = init_data(connect, collection)
1025
        expr = {"must": [gen_default_vector_expr(default_query), gen_default_term_expr()]}
1026
        query = update_query_expr(default_query, expr=expr)
1027
        res = connect.search(collection, query)
1028
        assert len(res) == nq
1029
        assert len(res[0]) == default_top_k
1030
        # TODO:
1031
1032 View Code Duplication
    def test_query_term_values_parts_in(self, connect, collection):
0 ignored issues
show
Duplication introduced by
This code seems to be duplicated in your project.
Loading history...
1033
        '''
1034
        method: build query with vector and term expr, with parts of term can be filtered
1035
        expected: filter pass
1036
        '''
1037
        entities, ids = init_data(connect, collection)
1038
        expr = {"must": [gen_default_vector_expr(default_query),
1039
                         gen_default_term_expr(values=[i for i in range(default_nb // 2, default_nb + default_nb // 2)])]}
1040
        query = update_query_expr(default_query, expr=expr)
1041
        res = connect.search(collection, query)
1042
        assert len(res) == nq
1043
        assert len(res[0]) == default_top_k
1044
        # TODO:
1045
1046
    # TODO:
1047 View Code Duplication
    @pytest.mark.level(2)
0 ignored issues
show
Duplication introduced by
This code seems to be duplicated in your project.
Loading history...
1048
    def test_query_term_values_repeat(self, connect, collection):
1049
        '''
1050
        method: build query with vector and term expr, with the same values
1051
        expected: filter pass
1052
        '''
1053
        entities, ids = init_data(connect, collection)
1054
        expr = {
1055
            "must": [gen_default_vector_expr(default_query), gen_default_term_expr(values=[1 for i in range(1, default_nb)])]}
1056
        query = update_query_expr(default_query, expr=expr)
1057
        res = connect.search(collection, query)
1058
        assert len(res) == nq
1059
        assert len(res[0]) == 1
1060
        # TODO:
1061
1062
    def test_query_term_value_empty(self, connect, collection):
1063
        '''
1064
        method: build query with term value empty
1065
        expected: return null
1066
        '''
1067
        expr = {"must": [gen_default_vector_expr(default_query), gen_default_term_expr(values=[])]}
1068
        query = update_query_expr(default_query, expr=expr)
1069
        res = connect.search(collection, query)
1070
        assert len(res) == nq
1071
        assert len(res[0]) == 0
1072
1073
    """
1074
    ******************************************************************
1075
    #  The following cases are used to build invalid term query expr
1076
    ******************************************************************
1077
    """

    # TODO
    @pytest.mark.level(2)
    def test_query_term_key_error(self, connect, collection):
        '''
        method: build query with a wrong term keyword
        expected: Exception raised
        '''
        expr = {"must": [gen_default_vector_expr(default_query),
                         gen_default_term_expr(keyword="terrm", values=[i for i in range(default_nb // 2)])]}
        query = update_query_expr(default_query, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    @pytest.fixture(
        scope="function",
        params=gen_invalid_term()
    )
    def get_invalid_term(self, request):
        return request.param

    @pytest.mark.level(2)
    def test_query_term_wrong_format(self, connect, collection, get_invalid_term):
        '''
        method: build query with a wrongly formatted term expr
        expected: Exception raised
        '''
        entities, ids = init_data(connect, collection)
        term = get_invalid_term
        expr = {"must": [gen_default_vector_expr(default_query), term]}
        query = update_query_expr(default_query, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    # TODO
    @pytest.mark.level(2)
    def test_query_term_field_named_term(self, connect, collection):
        '''
        method: build query on a collection with a field named "term"
        expected: search pass
        '''
        term_fields = add_field_default(default_fields, field_name="term")
        collection_term = gen_unique_str("term")
        connect.create_collection(collection_term, term_fields)
        term_entities = add_field(entities, field_name="term")
        ids = connect.insert(collection_term, term_entities)
        assert len(ids) == default_nb
        connect.flush([collection_term])
        count = connect.count_entities(collection_term)
        assert count == default_nb
        term_param = {"term": {"term": {"values": [i for i in range(default_nb // 2)]}}}
        expr = {"must": [gen_default_vector_expr(default_query),
                         term_param]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection_term, query)
        assert len(res) == nq
        assert len(res[0]) == default_top_k
        connect.drop_collection(collection_term)

    @pytest.mark.level(2)
    def test_query_term_one_field_not_existed(self, connect, collection):
        '''
        method: build query with term exprs on two fields, one of which does not exist
        expected: exception raised
        '''
        entities, ids = init_data(connect, collection)
        term = gen_default_term_expr()
        term["term"].update({"a": [0]})
        expr = {"must": [gen_default_vector_expr(default_query), term]}
        query = update_query_expr(default_query, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    """
    ******************************************************************
    #  The following cases are used to build valid range query expr
    ******************************************************************
    """

    # TODO
    def test_query_range_key_error(self, connect, collection):
        '''
        method: build query with a wrong range keyword
        expected: Exception raised
        '''
        range = gen_default_range_expr(keyword="ranges")
        expr = {"must": [gen_default_vector_expr(default_query), range]}
        query = update_query_expr(default_query, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    @pytest.fixture(
        scope="function",
        params=gen_invalid_range()
    )
    def get_invalid_range(self, request):
        return request.param

    # TODO
    @pytest.mark.level(2)
    def test_query_range_wrong_format(self, connect, collection, get_invalid_range):
        '''
        method: build query with a wrongly formatted range expr
        expected: Exception raised
        '''
        entities, ids = init_data(connect, collection)
        range = get_invalid_range
        expr = {"must": [gen_default_vector_expr(default_query), range]}
        query = update_query_expr(default_query, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    @pytest.mark.level(2)
    def test_query_range_string_ranges(self, connect, collection):
        '''
        method: build query with string-valued range bounds
        expected: raise Exception
        '''
        entities, ids = init_data(connect, collection)
        ranges = {"GT": "0", "LT": "1000"}
        range = gen_default_range_expr(ranges=ranges)
        expr = {"must": [gen_default_vector_expr(default_query), range]}
        query = update_query_expr(default_query, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    @pytest.mark.level(2)
    def test_query_range_invalid_ranges(self, connect, collection):
        '''
        method: build query with an empty range (lower bound above upper bound)
        expected: empty result
        '''
        entities, ids = init_data(connect, collection)
        ranges = {"GT": default_nb, "LT": 0}
        range = gen_default_range_expr(ranges=ranges)
        expr = {"must": [gen_default_vector_expr(default_query), range]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        assert len(res[0]) == 0

    @pytest.fixture(
        scope="function",
        params=gen_valid_ranges()
    )
    def get_valid_ranges(self, request):
        return request.param

    @pytest.mark.level(2)
    def test_query_range_valid_ranges(self, connect, collection, get_valid_ranges):
        '''
        method: build query with valid ranges
        expected: search pass
        '''
        entities, ids = init_data(connect, collection)
        ranges = get_valid_ranges
        range = gen_default_range_expr(ranges=ranges)
        expr = {"must": [gen_default_vector_expr(default_query), range]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == default_top_k

    def test_query_range_one_field_not_existed(self, connect, collection):
        '''
        method: build query with range exprs on two fields, one of which does not exist
        expected: exception raised
        '''
        entities, ids = init_data(connect, collection)
        range = gen_default_range_expr()
        range["range"].update({"a": {"GT": 1, "LT": default_nb // 2}})
        expr = {"must": [gen_default_vector_expr(default_query), range]}
        query = update_query_expr(default_query, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    """
    ************************************************************************
    #  The following cases are used to build query expr multi range and term
    ************************************************************************
    """

    # TODO
    @pytest.mark.level(2)
    def test_query_multi_term_has_common(self, connect, collection):
        '''
        method: build query with multi term exprs on the same field, with common values
        expected: pass
        '''
        entities, ids = init_data(connect, collection)
        term_first = gen_default_term_expr()
        term_second = gen_default_term_expr(values=[i for i in range(default_nb // 3)])
        expr = {"must": [gen_default_vector_expr(default_query), term_first, term_second]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == default_top_k

    # TODO
    @pytest.mark.level(2)
    def test_query_multi_term_no_common(self, connect, collection):
        '''
        method: build query with multi term exprs on the same field, with no common values
        expected: pass, empty result
        '''
        entities, ids = init_data(connect, collection)
        term_first = gen_default_term_expr()
        term_second = gen_default_term_expr(values=[i for i in range(default_nb // 2, default_nb + default_nb // 2)])
        expr = {"must": [gen_default_vector_expr(default_query), term_first, term_second]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == 0

    # TODO
    def test_query_multi_term_different_fields(self, connect, collection):
        '''
        method: build query with multi term exprs on different fields, with no common matching entities
        expected: pass, empty result
        '''
        entities, ids = init_data(connect, collection)
        term_first = gen_default_term_expr()
        term_second = gen_default_term_expr(field="float", values=[float(i) for i in range(default_nb // 2, default_nb)])
        expr = {"must": [gen_default_vector_expr(default_query), term_first, term_second]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == 0

    # TODO
    @pytest.mark.level(2)
    def test_query_single_term_multi_fields(self, connect, collection):
        '''
        method: build query with a single term expr containing multiple fields
        expected: exception raised
        '''
        entities, ids = init_data(connect, collection)
        term_first = {"int64": {"values": [i for i in range(default_nb // 2)]}}
        term_second = {"float": {"values": [float(i) for i in range(default_nb // 2, default_nb)]}}
        term = update_term_expr({"term": {}}, [term_first, term_second])
        expr = {"must": [gen_default_vector_expr(default_query), term]}
        query = update_query_expr(default_query, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    # TODO
    @pytest.mark.level(2)
    def test_query_multi_range_has_common(self, connect, collection):
        '''
        method: build query with multi range exprs on the same field, with overlapping ranges
        expected: pass
        '''
        entities, ids = init_data(connect, collection)
        range_one = gen_default_range_expr()
        range_two = gen_default_range_expr(ranges={"GT": 1, "LT": default_nb // 3})
        expr = {"must": [gen_default_vector_expr(default_query), range_one, range_two]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == default_top_k

    # TODO
    @pytest.mark.level(2)
    def test_query_multi_range_no_common(self, connect, collection):
        '''
        method: build query with multi range exprs on the same field, with disjoint ranges
        expected: pass, empty result
        '''
        entities, ids = init_data(connect, collection)
        range_one = gen_default_range_expr()
        range_two = gen_default_range_expr(ranges={"GT": default_nb // 2, "LT": default_nb})
        expr = {"must": [gen_default_vector_expr(default_query), range_one, range_two]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == 0

    # TODO
    @pytest.mark.level(2)
    def test_query_multi_range_different_fields(self, connect, collection):
        '''
        method: build query with multi range exprs, a different field in each range
        expected: pass, empty result
        '''
        entities, ids = init_data(connect, collection)
        range_first = gen_default_range_expr()
        range_second = gen_default_range_expr(field="float", ranges={"GT": default_nb // 2, "LT": default_nb})
        expr = {"must": [gen_default_vector_expr(default_query), range_first, range_second]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == 0

    # TODO
    @pytest.mark.level(2)
    def test_query_single_range_multi_fields(self, connect, collection):
        '''
        method: build query with a single range expr containing multiple fields
        expected: exception raised
        '''
        entities, ids = init_data(connect, collection)
        range_first = {"int64": {"GT": 0, "LT": default_nb // 2}}
        range_second = {"float": {"GT": default_nb / 2, "LT": float(default_nb)}}
        range = update_range_expr({"range": {}}, [range_first, range_second])
        expr = {"must": [gen_default_vector_expr(default_query), range]}
        query = update_query_expr(default_query, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    """
    ******************************************************************
    #  The following cases are used to build query expr both term and range
    ******************************************************************
    """

    # TODO
    @pytest.mark.level(2)
    def test_query_single_term_range_has_common(self, connect, collection):
        '''
        method: build query with a single term and a single range expr that overlap
        expected: pass
        '''
        entities, ids = init_data(connect, collection)
        term = gen_default_term_expr()
        range = gen_default_range_expr(ranges={"GT": -1, "LT": default_nb // 2})
        expr = {"must": [gen_default_vector_expr(default_query), term, range]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == default_top_k

    # TODO
    def test_query_single_term_range_no_common(self, connect, collection):
        '''
        method: build query with a single term and a single range expr that do not overlap
        expected: pass, empty result
        '''
        entities, ids = init_data(connect, collection)
        term = gen_default_term_expr()
        range = gen_default_range_expr(ranges={"GT": default_nb // 2, "LT": default_nb})
        expr = {"must": [gen_default_vector_expr(default_query), term, range]}
        query = update_query_expr(default_query, expr=expr)
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == 0

    """
    ******************************************************************
    #  The following cases are used to build multi vectors query expr
    ******************************************************************
    """

    # TODO
    def test_query_multi_vectors_same_field(self, connect, collection):
        '''
        method: build query with two vector exprs on the same field
        expected: error raised
        '''
        entities, ids = init_data(connect, collection)
        vector1 = default_query
        vector2 = gen_query_vectors(field_name, entities, default_top_k, nq=2)
        expr = {
            "must": [vector1, vector2]
        }
        query = update_query_expr(default_query, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)


class TestSearchDSLBools(object):
    """
    ******************************************************************
    #  The following cases are used to build invalid query expr
    ******************************************************************
    """

    @pytest.mark.level(2)
    def test_query_no_bool(self, connect, collection):
        '''
        method: build query without bool expr
        expected: error raised
        '''
        entities, ids = init_data(connect, collection)
        expr = {"bool1": {}}
        query = expr
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    def test_query_should_only_term(self, connect, collection):
        '''
        method: build query without must, with should.term instead
        expected: error raised
        '''
        expr = {"should": [gen_default_term_expr()]}
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    def test_query_should_only_vector(self, connect, collection):
        '''
        method: build query without must, with should.vector instead
        expected: error raised
        '''
        expr = {"should": default_query["bool"]["must"]}
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    def test_query_must_not_only_term(self, connect, collection):
        '''
        method: build query without must, with must_not.term instead
        expected: error raised
        '''
        expr = {"must_not": [gen_default_term_expr()]}
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    def test_query_must_not_vector(self, connect, collection):
        '''
        method: build query without must, with must_not.vector instead
        expected: error raised
        '''
        expr = {"must_not": default_query["bool"]["must"]}
        query = update_query_expr(default_query, keep_old=False, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    def test_query_must_should(self, connect, collection):
        '''
        method: build query with must, plus a should.term clause
        expected: error raised
        '''
        expr = {"should": [gen_default_term_expr()]}
        query = update_query_expr(default_query, keep_old=True, expr=expr)
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)


"""
******************************************************************
#  The following cases are used to test `search` function
#  with invalid collection_name, or invalid query expr
******************************************************************
"""

class TestSearchInvalid(object):
    """
    Test search collection with invalid collection names
    """

    @pytest.fixture(
        scope="function",
        params=gen_invalid_strs()
    )
    def get_collection_name(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_invalid_strs()
    )
    def get_invalid_tag(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_invalid_strs()
    )
    def get_invalid_field(self, request):
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_simple_index(self, request, connect):
        if str(connect._cmd("mode")) == "CPU":
            if request.param["index_type"] in index_cpu_not_support():
                pytest.skip("sq8h not support in CPU mode")
        return request.param

    @pytest.mark.level(2)
    def test_search_with_invalid_collection(self, connect, get_collection_name):
        collection_name = get_collection_name
        with pytest.raises(Exception) as e:
            res = connect.search(collection_name, default_query)

    # TODO(yukun)
    @pytest.mark.level(2)
    def _test_search_with_invalid_tag(self, connect, collection):
        tag = " "
        with pytest.raises(Exception) as e:
            res = connect.search(collection, default_query, partition_tags=tag)

    @pytest.mark.level(2)
    def test_search_with_invalid_field_name(self, connect, collection, get_invalid_field):
        fields = [get_invalid_field]
        with pytest.raises(Exception) as e:
            res = connect.search(collection, default_query, fields=fields)

    @pytest.mark.level(1)
    def test_search_with_not_existed_field_name(self, connect, collection):
        fields = [gen_unique_str("field_name")]
        with pytest.raises(Exception) as e:
            res = connect.search(collection, default_query, fields=fields)

    """
    Test search collection with invalid query
    """

    @pytest.fixture(
        scope="function",
        params=gen_invalid_ints()
    )
    def get_top_k(self, request):
        yield request.param

    @pytest.mark.level(1)
    def test_search_with_invalid_top_k(self, connect, collection, get_top_k):
        '''
        target: test search function, with the wrong top_k
        method: search with an invalid top_k
        expected: raise an error, and the connection is normal
        '''
        top_k = get_top_k
        default_query["bool"]["must"][0]["vector"][field_name]["topk"] = top_k
        with pytest.raises(Exception) as e:
            res = connect.search(collection, default_query)

    """
    Test search collection with invalid search params
    """

    @pytest.fixture(
        scope="function",
        params=gen_invaild_search_params()
    )
    def get_search_params(self, request):
        yield request.param

    @pytest.mark.level(2)
    def test_search_with_invalid_params(self, connect, collection, get_simple_index, get_search_params):
        '''
        target: test search function, with invalid search params
        method: search with invalid search params
        expected: raise an error, and the connection is normal
        '''
        search_params = get_search_params
        index_type = get_simple_index["index_type"]
        if index_type in ["FLAT"]:
            pytest.skip("skip in FLAT index")
        if index_type != search_params["index_type"]:
            pytest.skip("skip if index_type not matched")
        entities, ids = init_data(connect, collection)
        connect.create_index(collection, field_name, get_simple_index)
        query, vecs = gen_query_vectors(field_name, entities, default_top_k, 1, search_params=search_params["search_params"])
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)

    @pytest.mark.level(2)
    def test_search_with_invalid_params_binary(self, connect, binary_collection):
        '''
        target: test search function on a binary collection, with an invalid nprobe
        method: search with nprobe=0
        expected: raise an error, and the connection is normal
        '''
        nq = 1
        index_type = "BIN_IVF_FLAT"
        int_vectors, entities, ids = init_binary_data(connect, binary_collection)
        query_int_vectors, query_entities, tmp_ids = init_binary_data(connect, binary_collection, nb=1, insert=False)
        connect.create_index(binary_collection, binary_field_name, {"index_type": index_type, "metric_type": "JACCARD", "params": {"nlist": 128}})
        query, vecs = gen_query_vectors(binary_field_name, query_entities, default_top_k, nq, search_params={"nprobe": 0}, metric_type="JACCARD")
        with pytest.raises(Exception) as e:
            res = connect.search(binary_collection, query)

    @pytest.mark.level(2)
    def test_search_with_empty_params(self, connect, collection, args, get_simple_index):
        '''
        target: test search function, with empty search params
        method: search with empty params
        expected: raise an error, and the connection is normal
        '''
        index_type = get_simple_index["index_type"]
        if args["handler"] == "HTTP":
            pytest.skip("skip in http mode")
        if index_type == "FLAT":
            pytest.skip("skip in FLAT index")
        entities, ids = init_data(connect, collection)
        connect.create_index(collection, field_name, get_simple_index)
        query, vecs = gen_query_vectors(field_name, entities, default_top_k, 1, search_params={})
        with pytest.raises(Exception) as e:
            res = connect.search(collection, query)


def check_id_result(result, id):
    limit_in = 5
    ids = [entity.id for entity in result]
    if len(result) >= limit_in:
        return id in ids[:limit_in]
    else:
        return id in ids
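# Usage note (a sketch; `res` is the result object returned by connect.search
# and `ids` the inserted entity ids in these tests):
#
#   for i in range(nq):
#       assert check_id_result(res[i], ids[i])
#
# i.e. the expected id has to appear in the top-5 hits when at least five
# results are returned, otherwise anywhere in the (shorter) result list.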