Passed
Push — master ( 52a593...841a75 )
by
unknown
02:10
created

TestCreateCollectionInvalid.get_segment_size()   A

Complexity

Conditions 1

Size

Total Lines 6
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 5
nop 2
dl 0
loc 6
rs 10
c 0
b 0
f 0
1
import pdb
2
import copy
3
import logging
4
import itertools
5
from time import sleep
6
import threading
7
from multiprocessing import Process
8
import sklearn.preprocessing
9
10
import pytest
11
from utils import *
12
13
# Module-level defaults shared by the test cases in this file.
nb = 1                                   # entity count passed to gen_entities() below
dim = 128                                # not referenced by the tests visible in this chunk
collection_id = "create_collection"      # prefix handed to gen_unique_str() for collection names
default_segment_row_count = 100000       # value expected when "segment_row_count" is omitted
drop_collection_interval_time = 3        # not referenced by the tests visible in this chunk
segment_row_count = 5000                 # explicit segment_row_count used by the field tests
default_fields = gen_default_fields()    # template schema; tests deepcopy it before mutating
entities = gen_entities(nb)              # payload inserted by the insert-then-create tests
21
22
class TestCreateCollection:
    """
    ******************************************************************
      The following cases are used to test `create_collection` function
    ******************************************************************
    """

    @pytest.fixture(scope="function", params=gen_single_filter_fields())
    def get_filter_field(self, request):
        # One field definition from gen_single_filter_fields() per parametrized run.
        yield request.param

    @pytest.fixture(scope="function", params=gen_single_vector_fields())
    def get_vector_field(self, request):
        # One field definition from gen_single_vector_fields() per parametrized run.
        yield request.param

    @pytest.fixture(scope="function", params=gen_segment_row_counts())
    def get_segment_row_count(self, request):
        # One segment_row_count value from gen_segment_row_counts() per parametrized run.
        yield request.param

    def test_create_collection_fields(self, connect, get_filter_field, get_vector_field):
        '''
        target: test create normal collection with different fields
        method: create collection with diff fields: metric/field_type/...
        expected: no exception raised
        '''
        filter_field = get_filter_field
        logging.getLogger().info(filter_field)
        vector_field = get_vector_field
        collection_name = gen_unique_str(collection_id)
        fields = {
            "fields": [filter_field, vector_field],
            "segment_row_count": segment_row_count
        }
        logging.getLogger().info(fields)
        connect.create_collection(collection_name, fields)
        assert connect.has_collection(collection_name)

    # TODO: build an index on the new collection; for now this only creates
    # the collection and checks that it exists.
    def test_create_collection_fields_create_index(self, connect, get_filter_field, get_vector_field):
        '''
        target: test create normal collection with different fields
        method: create collection with diff fields: metric/field_type/...
        expected: no exception raised
        '''
        filter_field = get_filter_field
        vector_field = get_vector_field
        collection_name = gen_unique_str(collection_id)
        fields = {
            "fields": [filter_field, vector_field],
            "segment_row_count": segment_row_count
        }
        connect.create_collection(collection_name, fields)
        assert connect.has_collection(collection_name)

    def test_create_collection_segment_row_count(self, connect, get_segment_row_count):
        '''
        target: test create normal collection with different fields
        method: create collection with diff segment_row_count
        expected: no exception raised
        '''
        collection_name = gen_unique_str(collection_id)
        fields = copy.deepcopy(default_fields)
        fields["segment_row_count"] = get_segment_row_count
        connect.create_collection(collection_name, fields)
        assert connect.has_collection(collection_name)

    def test_create_collection_auto_flush_disabled(self, connect):
        '''
        target: test create normal collection, with large auto_flush_interval
        method: disable auto flush, then create collection with correct params
        expected: collection is created
        '''
        disable_flush(connect)
        collection_name = gen_unique_str(collection_id)
        try:
            connect.create_collection(collection_name, default_fields)
        finally:
            # Re-enable auto flush even if creation fails, so the setting does
            # not leak into other tests sharing this connection.
            enable_flush(connect)
        # Previously this test asserted nothing; verify the create took effect.
        assert connect.has_collection(collection_name)

    def test_create_collection_after_insert(self, connect, collection):
        '''
        target: test insert vector, then create collection again
        method: insert vector, then create a collection with the existing name
        expected: error raised
        '''
        connect.insert(collection, entities)
        with pytest.raises(Exception):
            connect.create_collection(collection, default_fields)

    def test_create_collection_after_insert_flush(self, connect, collection):
        '''
        target: test insert vector, then create collection again
        method: insert vector, flush, then create a collection with the existing name
        expected: error raised
        '''
        connect.insert(collection, entities)
        connect.flush([collection])
        with pytest.raises(Exception):
            connect.create_collection(collection, default_fields)

    @pytest.mark.level(2)
    def test_create_collection_without_connection(self, dis_connect):
        '''
        target: test create collection, without connection
        method: create collection with correct params, with a disconnected instance
        expected: create raise exception
        '''
        collection_name = gen_unique_str(collection_id)
        # Must use the disconnected client: the undefined name `connect` raised
        # NameError, which pytest.raises(Exception) silently accepted.
        with pytest.raises(Exception):
            dis_connect.create_collection(collection_name, default_fields)

    def test_create_collection_existed(self, connect):
        '''
        target: test create collection but the collection name have already existed
        method: create collection with the same collection_name
        expected: create status return not ok
        '''
        collection_name = gen_unique_str(collection_id)
        connect.create_collection(collection_name, default_fields)
        with pytest.raises(Exception):
            connect.create_collection(collection_name, default_fields)

    @pytest.mark.level(2)
    def test_create_collection_multithread(self, connect):
        '''
        target: test create collection with multithread
        method: create collections concurrently from several threads
        expected: collections are created
        '''
        threads_num = 8
        threads = []
        collection_names = []

        def create():
            collection_name = gen_unique_str(collection_id)
            # list.append is atomic in CPython, so sharing the list is safe here.
            collection_names.append(collection_name)
            connect.create_collection(collection_name, default_fields)

        for _ in range(threads_num):
            t = threading.Thread(target=create)
            threads.append(t)
            t.start()
            # Use the explicitly imported `sleep`; `time` itself is not imported here.
            sleep(0.2)
        for t in threads:
            t.join()

        res = connect.list_collections()
        for item in collection_names:
            assert item in res
183
184
185
class TestCreateCollectionInvalid(object):
186
    """
187
    Test creating collections with invalid params
188
    """
189
    @pytest.fixture(scope="function", params=gen_invalid_metric_types())
    def get_metric_type(self, request):
        """Yield one invalid metric type from gen_invalid_metric_types() per run.

        NOTE(review): only the commented-out metric_type test in this chunk
        references this fixture.
        """
        invalid_metric_type = request.param
        yield invalid_metric_type
195
196
    @pytest.fixture(scope="function", params=gen_invalid_ints())
    def get_segment_row_count(self, request):
        """Yield one invalid segment_row_count value from gen_invalid_ints() per run."""
        invalid_count = request.param
        yield invalid_count
202
203
    @pytest.fixture(scope="function", params=gen_invalid_ints())
    def get_dim(self, request):
        """Yield one invalid vector dimension from gen_invalid_ints() per run."""
        invalid_dim = request.param
        yield invalid_dim
209
210
    @pytest.fixture(scope="function", params=gen_invalid_strs())
    def get_invalid_string(self, request):
        """Yield one invalid name string from gen_invalid_strs() per run."""
        invalid_str = request.param
        yield invalid_str
216
217
    @pytest.fixture(scope="function", params=gen_invalid_field_types())
    def get_field_type(self, request):
        """Yield one invalid field type from gen_invalid_field_types() per run."""
        invalid_type = request.param
        yield invalid_type
223
224
    @pytest.mark.level(2)
    def test_create_collection_with_invalid_segment_row_count(self, connect, get_segment_row_count):
        '''
        target: test create collection with an invalid segment_row_count
        method: set segment_row_count on a copy of default_fields to an invalid value
        expected: create raise exception
        '''
        name = gen_unique_str()
        invalid_fields = copy.deepcopy(default_fields)
        invalid_fields["segment_row_count"] = get_segment_row_count
        with pytest.raises(Exception):
            connect.create_collection(name, invalid_fields)
231
232
    # @pytest.mark.level(2)
233
    # def test_create_collection_with_invalid_metric_type(self, connect, get_metric_type):
234
    #     collection_name = gen_unique_str()
235
    #     fields = copy.deepcopy(default_fields)
236
    #     fields["fields"][-1]["params"]["metric_type"] = get_metric_type
237
    #     with pytest.raises(Exception) as e:
238
    #         connect.create_collection(collection_name, fields)
239
240
    @pytest.mark.level(2)
    def test_create_collection_with_invalid_dimension(self, connect, get_dim):
        '''
        target: test create collection with an invalid dimension
        method: overwrite "dim" in the params of the last field of a copy of default_fields
        expected: create raise exception
        '''
        name = gen_unique_str()
        invalid_fields = copy.deepcopy(default_fields)
        # The last field is assumed to be the vector field that carries "dim".
        last_field_params = invalid_fields["fields"][-1]["params"]
        last_field_params["dim"] = get_dim
        with pytest.raises(Exception):
            connect.create_collection(name, invalid_fields)
248
249
    @pytest.mark.level(2)
    def test_create_collection_with_invalid_collectionname(self, connect, get_invalid_string):
        '''
        target: test create collection with an invalid collection name
        method: create collection named by a value from gen_invalid_strs()
        expected: create raise exception
        '''
        with pytest.raises(Exception):
            connect.create_collection(get_invalid_string, default_fields)
254
255
    @pytest.mark.level(2)
    def test_create_collection_with_empty_collectionname(self, connect):
        '''
        target: test create collection with an empty collection name
        method: create collection with '' as collection_name
        expected: create raise exception
        '''
        empty_name = ''
        with pytest.raises(Exception):
            connect.create_collection(empty_name, default_fields)
260
261
    @pytest.mark.level(2)
    def test_create_collection_with_none_collectionname(self, connect):
        '''
        target: test create collection with None as collection name
        method: create collection with collection_name=None
        expected: create raise exception
        '''
        no_name = None
        with pytest.raises(Exception):
            connect.create_collection(no_name, default_fields)
266
267
    def test_create_collection_None(self, connect):
        '''
        target: test create collection but the collection name is None
        method: create collection, param collection_name is None
        expected: create raise error
        '''
        missing_name = None
        with pytest.raises(Exception):
            connect.create_collection(missing_name, default_fields)
275
276
    def test_create_collection_no_dimension(self, connect):
        '''
        target: test create collection with no dimension params
        method: remove "dim" from the params of the last field of a copy of default_fields
        expected: create raise exception
        '''
        collection_name = gen_unique_str(collection_id)
        fields = copy.deepcopy(default_fields)
        # The last field is assumed to be the vector field that carries "dim".
        fields["fields"][-1]["params"].pop("dim")
        with pytest.raises(Exception):
            connect.create_collection(collection_name, fields)
287
288
    def test_create_collection_no_segment_row_count(self, connect):
        '''
        target: test create collection with no segment_row_count params
        method: drop segment_row_count from a copy of default_fields, then create collection
        expected: collection info reports default_segment_row_count
        '''
        name = gen_unique_str(collection_id)
        schema = copy.deepcopy(default_fields)
        del schema["segment_row_count"]
        connect.create_collection(name, schema)
        info = connect.get_collection_info(name)
        logging.getLogger().info(info)
        assert info["segment_row_count"] == default_segment_row_count
301
302
    # def _test_create_collection_no_metric_type(self, connect):
303
    #     '''
304
    #     target: test create collection with no metric_type params
305
    #     method: create collection with corrent params
306
    #     expected: use default L2
307
    #     '''
308
    #     collection_name = gen_unique_str(collection_id)
309
    #     fields = copy.deepcopy(default_fields)
310
    #     fields["fields"][-1]["params"].pop("metric_type")
311
    #     connect.create_collection(collection_name, fields)
312
    #     res = connect.get_collection_info(collection_name)
313
    #     logging.getLogger().info(res)
314
    #     assert res["metric_type"] == "L2"
315
316
    # TODO: assert exception
317
    def test_create_collection_limit_fields(self, connect):
        '''
        target: test create collection with a large number of fields
        method: append 64 extra INT64 fields to a copy of default_fields
        expected: create raise exception (presumably exceeds the server's field limit)
        '''
        name = gen_unique_str(collection_id)
        limit_num = 64
        schema = copy.deepcopy(default_fields)
        extra_fields = [
            {"field": gen_unique_str("field_name"), "type": DataType.INT64}
            for _ in range(limit_num)
        ]
        schema["fields"].extend(extra_fields)
        with pytest.raises(Exception):
            connect.create_collection(name, schema)
327
328
    # TODO: assert exception
329
    def test_create_collection_invalid_field_name(self, connect, get_invalid_string):
        '''
        target: test create collection with an invalid field name
        method: append an INT64 field named by gen_invalid_strs() to a copy of default_fields
        expected: create raise exception
        '''
        name = gen_unique_str(collection_id)
        schema = copy.deepcopy(default_fields)
        bad_field = {"field": get_invalid_string, "type": DataType.INT64}
        schema["fields"].append(bad_field)
        with pytest.raises(Exception):
            connect.create_collection(name, schema)
337
338
    # TODO: assert exception
339
    def test_create_collection_invalid_field_type(self, connect, get_field_type):
        '''
        target: test create collection with an invalid field type
        method: append a field "test_field" whose type comes from gen_invalid_field_types()
        expected: create raise exception
        '''
        name = gen_unique_str(collection_id)
        schema = copy.deepcopy(default_fields)
        schema["fields"].append({"field": "test_field", "type": get_field_type})
        with pytest.raises(Exception):
            connect.create_collection(name, schema)
347