Passed
Push — master ( fd4969...54df52 )
by
unknown
01:50
created

TestGetInvalid.get_collection_name()   A

Complexity

Conditions 1

Size

Total Lines 6
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 5
nop 2
dl 0
loc 6
rs 10
c 0
b 0
f 0
1
import time
2
import random
3
import pdb
4
import copy
5
import threading
6
import logging
7
from multiprocessing import Pool, Process
8
import concurrent.futures
9
import pytest
10
from utils import *
11
12
# Shared configuration and pre-generated test data for the get_entity_by_id cases.
dim = 128
segment_size = 10
collection_id = "test_get"
DELETE_TIMEOUT = 60
tag = "1970-01-01"
nb = 6000
field_name = "float_entity"
default_index_name = "insert_index"
# A single float entity, reused by the single-entity / same-id cases.
entity = gen_entities(1)
# gen_binary_entities returns a (raw_vectors, entities) pair (see the batch
# unpacking below), so unpack the single-entity result the same way instead of
# binding the whole tuple.
_, binary_entity = gen_binary_entities(1)
# A batch of `nb` float entities, and a batch of `nb` binary entities.
entities = gen_entities(nb)
raw_vectors, binary_entities = gen_binary_entities(nb)
default_single_query = {
    "bool": {
        "must": [
            {"vector": {field_name: {"topk": 10, "query": gen_vectors(1, dim), "params": {"nprobe": 10}}}}
        ]
    }
}


class TestGetBase:
    """
    ******************************************************************
      The following cases are used to test `get_entity_by_id` function
    ******************************************************************
    """
    @pytest.fixture(
        scope="function",
        params=gen_simple_index()
    )
    def get_simple_index(self, request, connect):
        """Return each simple index params dict from gen_simple_index().

        When the server reports CPU mode, index types listed by
        index_cpu_not_support() are skipped.
        """
        if str(connect._cmd("mode")) == "CPU":
            if request.param["index_type"] in index_cpu_not_support():
                pytest.skip("sq8h not support in CPU mode")
        return request.param

    @pytest.fixture(
        scope="function",
        params=[
            1,
            10,
            100,
            500
        ],
    )
    def get_pos(self, request):
        """Yield a position / count used to pick ids from the inserted batch."""
        yield request.param

    def test_get_entity(self, connect, collection, get_pos):
        '''
        target: test.get_entity_by_id, get one
        method: add entities, flush, and get the entity at position get_pos
        expected: entity returned equals insert
        '''
        ids = connect.insert(collection, entities)
        connect.flush([collection])
        res_count = connect.count_entities(collection)
        assert res_count == nb
        get_ids = [ids[get_pos]]
        res = connect.get_entity_by_id(collection, get_ids)
        assert_equal_vector(res[0].get("vector"), entities[-1]["values"][get_pos])

    def test_get_entity_multi_ids(self, connect, collection, get_pos):
        '''
        target: test.get_entity_by_id, get multiple ids in one call
        method: add entities, and get the first get_pos of them
        expected: entities returned equal insert, in request order
        '''
        ids = connect.insert(collection, entities)
        connect.flush([collection])
        get_ids = ids[:get_pos]
        res = connect.get_entity_by_id(collection, get_ids)
        assert len(res) == len(get_ids)
        for i in range(get_pos):
            assert_equal_vector(res[i].get("vector"), entities[-1]["values"][i])

    def test_get_entity_parts_ids(self, connect, collection):
        '''
        target: test.get_entity_by_id, some ids in collection, some ids not
        method: add entity, and get
        expected: existing ids return the inserted vectors, the missing id returns None
        '''
        ids = connect.insert(collection, entities)
        connect.flush([collection])
        get_ids = [ids[0], 1, ids[-1]]
        res = connect.get_entity_by_id(collection, get_ids)
        assert_equal_vector(res[0].get("vector"), entities[-1]["values"][0])
        assert_equal_vector(res[-1].get("vector"), entities[-1]["values"][-1])
        assert res[1] is None

    def test_get_entity_limit(self, connect, collection, args):
        '''
        target: test.get_entity_by_id
        method: add nb entities, and get all of them in one call (limit > 1000)
        expected: exception raised
        '''
        if args["handler"] == "HTTP":
            pytest.skip("skip in http mode")

        ids = connect.insert(collection, entities)
        connect.flush([collection])
        with pytest.raises(Exception):
            connect.get_entity_by_id(collection, ids)

    def test_get_entity_same_ids(self, connect, collection):
        '''
        target: test.get_entity_by_id, with the same ids
        method: add nb entities that all share id 1, and get that one id
        expected: one entity returned, equal to the first inserted one
        '''
        ids = [1] * nb
        connect.insert(collection, entities, ids)
        connect.flush([collection])
        get_ids = [ids[0]]
        res = connect.get_entity_by_id(collection, get_ids)
        assert len(res) == 1
        assert_equal_vector(res[0].get("vector"), entities[-1]["values"][0])

    def test_get_entity_params_same_ids(self, connect, collection):
        '''
        target: test.get_entity_by_id, with the same ids
        method: add one entity, and get it with a request that repeats its id
        expected: one result per requested id, each equal to the insert
        '''
        ids = [1]
        connect.insert(collection, entity, ids)
        connect.flush([collection])
        get_ids = [1, 1]
        res = connect.get_entity_by_id(collection, get_ids)
        assert len(res) == len(get_ids)
        for item in res:
            assert_equal_vector(item.get("vector"), entity[-1]["values"][0])

    def test_get_entities_params_same_ids(self, connect, collection):
        '''
        target: test.get_entity_by_id, with the same ids
        method: add entities, and get with a request that repeats one id
        expected: one result per requested id, each equal to the insert
        '''
        res_ids = connect.insert(collection, entities)
        connect.flush([collection])
        get_ids = [res_ids[0], res_ids[0]]
        res = connect.get_entity_by_id(collection, get_ids)
        assert len(res) == len(get_ids)
        for item in res:
            assert_equal_vector(item.get("vector"), entities[-1]["values"][0])

    # ******************************************************************
    #   The following cases are used to test `get_entity_by_id` function,
    #   with different metric type
    # ******************************************************************

    def test_get_entity_parts_ids_ip(self, connect, ip_collection):
        '''
        target: test.get_entity_by_id, some ids in ip_collection, some ids not
        method: add entity, and get
        expected: existing ids return the inserted vectors, the missing id returns None
        '''
        ids = connect.insert(ip_collection, entities)
        connect.flush([ip_collection])
        get_ids = [ids[0], 1, ids[-1]]
        res = connect.get_entity_by_id(ip_collection, get_ids)
        assert_equal_vector(res[0].get("vector"), entities[-1]["values"][0])
        assert_equal_vector(res[-1].get("vector"), entities[-1]["values"][-1])
        assert res[1] is None

    def test_get_entity_parts_ids_jac(self, connect, jac_collection):
        '''
        target: test.get_entity_by_id, some ids in jac_collection, some ids not
        method: add binary entities, and get
        expected: existing ids return the inserted vectors, the missing id returns None
        '''
        ids = connect.insert(jac_collection, binary_entities)
        connect.flush([jac_collection])
        get_ids = [ids[0], 1, ids[-1]]
        res = connect.get_entity_by_id(jac_collection, get_ids)
        assert_equal_vector(res[0].get("binary_vector"), binary_entities[-1]["values"][0])
        assert_equal_vector(res[-1].get("binary_vector"), binary_entities[-1]["values"][-1])
        assert res[1] is None

    # ******************************************************************
    #   The following cases are used to test `get_entity_by_id` function,
    #   with tags
    # ******************************************************************

    def test_get_entities_tag(self, connect, collection, get_pos):
        '''
        target: test.get_entity_by_id
        method: add entities into partition `tag`, get
        expected: entity returned
        '''
        connect.create_partition(collection, tag)
        ids = connect.insert(collection, entities, partition_tag=tag)
        connect.flush([collection])
        get_ids = ids[:get_pos]
        res = connect.get_entity_by_id(collection, get_ids)
        assert len(res) == len(get_ids)
        for i in range(get_pos):
            assert_equal_vector(res[i].get("vector"), entities[-1]["values"][i])

    def test_get_entities_tag_default(self, connect, collection, get_pos):
        '''
        target: test.get_entity_by_id
        method: create a partition, add entities with default tag, get
        expected: entity returned
        '''
        connect.create_partition(collection, tag)
        ids = connect.insert(collection, entities)
        connect.flush([collection])
        get_ids = ids[:get_pos]
        res = connect.get_entity_by_id(collection, get_ids)
        assert len(res) == len(get_ids)
        for i in range(get_pos):
            assert_equal_vector(res[i].get("vector"), entities[-1]["values"][i])

    def test_get_entities_tags_default(self, connect, collection, get_pos):
        '''
        target: test.get_entity_by_id
        method: create partitions, add entities with default tag, get
        expected: entity returned
        '''
        tag_new = "tag_new"
        connect.create_partition(collection, tag)
        connect.create_partition(collection, tag_new)
        ids = connect.insert(collection, entities)
        connect.flush([collection])
        get_ids = ids[:get_pos]
        res = connect.get_entity_by_id(collection, get_ids)
        assert len(res) == len(get_ids)
        for i in range(get_pos):
            assert_equal_vector(res[i].get("vector"), entities[-1]["values"][i])

    def test_get_entities_tags_A(self, connect, collection, get_pos):
        '''
        target: test.get_entity_by_id
        method: create two partitions, add entities into partition `tag` only, get
        expected: entity returned
        '''
        tag_new = "tag_new"
        connect.create_partition(collection, tag)
        connect.create_partition(collection, tag_new)
        ids = connect.insert(collection, entities, partition_tag=tag)
        connect.flush([collection])
        get_ids = ids[:get_pos]
        res = connect.get_entity_by_id(collection, get_ids)
        assert len(res) == len(get_ids)
        for i in range(get_pos):
            assert_equal_vector(res[i].get("vector"), entities[-1]["values"][i])

    def test_get_entities_tags_B(self, connect, collection, get_pos):
        '''
        target: test.get_entity_by_id
        method: create two partitions, add a different batch into each, and get
                ids from both partitions in one call
        expected: entities returned, each matching the batch it came from
        '''
        tag_new = "tag_new"
        connect.create_partition(collection, tag)
        connect.create_partition(collection, tag_new)
        new_entities = gen_entities(nb + 1)
        ids = connect.insert(collection, entities, partition_tag=tag)
        ids_new = connect.insert(collection, new_entities, partition_tag=tag_new)
        connect.flush([collection])
        get_ids = ids[:get_pos]
        get_ids.extend(ids_new[:get_pos])
        res = connect.get_entity_by_id(collection, get_ids)
        assert len(res) == len(get_ids)
        # First half of the results comes from `tag`, second half from `tag_new`.
        for i in range(get_pos):
            assert_equal_vector(res[i].get("vector"), entities[-1]["values"][i])
            assert_equal_vector(res[get_pos + i].get("vector"), new_entities[-1]["values"][i])

    def test_get_entities_indexed_tag(self, connect, collection, get_simple_index, get_pos):
        '''
        target: test.get_entity_by_id
        method: add entities with tag, create index, get
        expected: entity returned
        '''
        connect.create_partition(collection, tag)
        ids = connect.insert(collection, entities, partition_tag=tag)
        connect.flush([collection])
        connect.create_index(collection, field_name, default_index_name, get_simple_index)
        get_ids = ids[:get_pos]
        res = connect.get_entity_by_id(collection, get_ids)
        assert len(res) == len(get_ids)
        for i in range(get_pos):
            assert_equal_vector(res[i].get("vector"), entities[-1]["values"][i])

    # ******************************************************************
    #   The following cases are used to test `get_entity_by_id` function,
    #   with fields params
    # ******************************************************************

    # TODO: assert the returned field values once the result format for
    # `fields` is settled; for now only check one result per requested id.
    def test_get_entity_field(self, connect, collection, get_pos):
        '''
        target: test.get_entity_by_id, get one, restricted to one field
        method: add entities, and get with fields=["int8"]
        expected: one entity returned
        '''
        ids = connect.insert(collection, entities)
        connect.flush([collection])
        get_ids = [ids[get_pos]]
        fields = ["int8"]
        res = connect.get_entity_by_id(collection, get_ids, fields=fields)
        assert len(res) == len(get_ids)

    # TODO: assert the returned field values once the result format for
    # `fields` is settled; for now only check one result per requested id.
    def test_get_entity_fields(self, connect, collection, get_pos):
        '''
        target: test.get_entity_by_id, get one, restricted to several fields
        method: add entities, and get with several valid field names
        expected: one entity returned
        '''
        ids = connect.insert(collection, entities)
        connect.flush([collection])
        get_ids = [ids[get_pos]]
        fields = ["int8", "int64", "float", "vector"]
        res = connect.get_entity_by_id(collection, get_ids, fields=fields)
        assert len(res) == len(get_ids)

    # TODO: assert the specific exception type / message
    def test_get_entity_field_not_match(self, connect, collection, get_pos):
        '''
        target: test.get_entity_by_id, with an unknown field name
        method: add entities, and get with fields=["int1288"]
        expected: exception raised
        '''
        ids = connect.insert(collection, entities)
        connect.flush([collection])
        get_ids = [ids[get_pos]]
        fields = ["int1288"]
        with pytest.raises(Exception):
            connect.get_entity_by_id(collection, get_ids, fields=fields)

    # TODO: assert the specific exception type / message
    def test_get_entity_fields_not_match(self, connect, collection, get_pos):
        '''
        target: test.get_entity_by_id, with one unknown and one valid field name
        method: add entities, and get with fields=["int1288", "int8"]
        expected: exception raised
        '''
        ids = connect.insert(collection, entities)
        connect.flush([collection])
        get_ids = [ids[get_pos]]
        fields = ["int1288", "int8"]
        with pytest.raises(Exception):
            connect.get_entity_by_id(collection, get_ids, fields=fields)

    def test_get_entity_id_not_exised(self, connect, collection):
        '''
        target: test get entity, params entity_id not existed
        method: add entity and get an id that was not inserted
        expected: empty result (None)
        '''
        connect.insert(collection, entity)
        connect.flush([collection])
        res = connect.get_entity_by_id(collection, [1])
        assert res[0] is None

    def test_get_entity_collection_not_existed(self, connect, collection):
        '''
        target: test get entity, params collection_name not existed
        method: add entity and get from a freshly generated collection name
        expected: error raised
        '''
        ids = connect.insert(collection, entity)
        connect.flush([collection])
        collection_new = gen_unique_str()
        with pytest.raises(Exception):
            connect.get_entity_by_id(collection_new, [ids[0]])

    # ******************************************************************
    #   The following cases are used to test `get_entity_by_id` function,
    #   after deleted
    # ******************************************************************

    def test_get_entity_after_delete(self, connect, collection, get_pos):
        '''
        target: test.get_entity_by_id
        method: add entities, delete one, flush, get entity by the deleted id
        expected: empty result (None)
        '''
        ids = connect.insert(collection, entities)
        connect.flush([collection])
        delete_ids = [ids[get_pos]]
        connect.delete_entity_by_id(collection, delete_ids)
        connect.flush([collection])
        get_ids = [ids[get_pos]]
        res = connect.get_entity_by_id(collection, get_ids)
        assert res[0] is None

    def test_get_entities_after_delete(self, connect, collection, get_pos):
        '''
        target: test.get_entity_by_id
        method: add entities, delete the first get_pos, flush, get them by id
        expected: empty result (None) for every deleted id
        '''
        ids = connect.insert(collection, entities)
        connect.flush([collection])
        delete_ids = ids[:get_pos]
        connect.delete_entity_by_id(collection, delete_ids)
        connect.flush([collection])
        get_ids = delete_ids
        res = connect.get_entity_by_id(collection, get_ids)
        for i in range(get_pos):
            assert res[i] is None

    def test_get_entities_after_delete_compact(self, connect, collection, get_pos):
        '''
        target: test.get_entity_by_id
        method: add entities, delete the first get_pos, flush, compact, get them by id
        expected: empty result (None) for every deleted id
        '''
        ids = connect.insert(collection, entities)
        connect.flush([collection])
        delete_ids = ids[:get_pos]
        connect.delete_entity_by_id(collection, delete_ids)
        connect.flush([collection])
        connect.compact(collection)
        get_ids = ids[:get_pos]
        res = connect.get_entity_by_id(collection, get_ids)
        for i in range(get_pos):
            assert res[i] is None

    def test_get_entities_indexed_batch(self, connect, collection, get_simple_index, get_pos):
        '''
        target: test.get_entity_by_id
        method: add entities batch, create index, get
        expected: entity returned
        '''
        ids = connect.insert(collection, entities)
        connect.flush([collection])
        connect.create_index(collection, field_name, default_index_name, get_simple_index)
        get_ids = ids[:get_pos]
        res = connect.get_entity_by_id(collection, get_ids)
        assert len(res) == len(get_ids)
        for i in range(get_pos):
            assert_equal_vector(res[i].get("vector"), entities[-1]["values"][i])

    def test_get_entities_indexed_single(self, connect, collection, get_simple_index, get_pos):
        '''
        target: test.get_entity_by_id
        method: add entities 1 entity/per request, create index, get
        expected: entity returned
        '''
        # Each insert returns a list of ids; keep the single id of each call.
        ids = [connect.insert(collection, entity)[0] for _ in range(nb)]
        connect.flush([collection])
        connect.create_index(collection, field_name, default_index_name, get_simple_index)
        get_ids = ids[:get_pos]
        res = connect.get_entity_by_id(collection, get_ids)
        assert len(res) == len(get_ids)
        for item in res:
            assert_equal_vector(item.get("vector"), entity[-1]["values"][0])

    def test_get_entities_after_delete_disable_autoflush(self, connect, collection, get_pos):
        '''
        target: test.get_entity_by_id
        method: disable autoflush, add entities, and delete (without flush), get entity by the given id
        expected: entities still returned, because the delete has not been flushed
        '''
        ids = connect.insert(collection, entities)
        connect.flush([collection])
        delete_ids = ids[:get_pos]
        try:
            disable_flush(connect)
            connect.delete_entity_by_id(collection, delete_ids)
            get_ids = ids[:get_pos]
            res = connect.get_entity_by_id(collection, get_ids)
            for i in range(get_pos):
                assert_equal_vector(res[i].get("vector"), entities[-1]["values"][i])
        finally:
            # Always restore autoflush so later tests are unaffected.
            enable_flush(connect)

    # TODO:
    def test_get_entities_after_delete_same_ids(self, connect, collection):
        '''
        target: test.get_entity_by_id
        method: add entities where ids 0 and 1 both become 1, delete id 1, get it
        expected: empty result (None)
        '''
        ids = list(range(nb))
        ids[0] = 1
        connect.insert(collection, entities, ids)
        connect.flush([collection])
        connect.delete_entity_by_id(collection, [1])
        connect.flush([collection])
        get_ids = [1]
        res = connect.get_entity_by_id(collection, get_ids)
        assert res[0] is None

    def test_get_entity_after_delete_with_partition(self, connect, collection, get_pos):
        '''
        target: test.get_entity_by_id
        method: add entities into partition, and delete, get entity by the given id
        expected: empty result (None)
        '''
        connect.create_partition(collection, tag)
        ids = connect.insert(collection, entities, partition_tag=tag)
        connect.flush([collection])
        connect.delete_entity_by_id(collection, [ids[get_pos]])
        connect.flush([collection])
        res = connect.get_entity_by_id(collection, [ids[get_pos]])
        assert res[0] is None

    def test_get_entity_by_id_multithreads(self, connect, collection):
        '''
        target: test.get_entity_by_id, called concurrently
        method: add entities, then submit 10 identical get tasks (100 ids each)
                to a 5-worker thread pool
        expected: every call returns the inserted vectors
        '''
        ids = connect.insert(collection, entities)
        connect.flush([collection])
        get_id = ids[100:200]

        def get():
            res = connect.get_entity_by_id(collection, get_id)
            assert len(res) == len(get_id)
            for i, item in enumerate(res):
                assert_equal_vector(item.get("vector"), entities[-1]["values"][100 + i])

        with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
            futures = [executor.submit(get) for _ in range(10)]
            # result() re-raises any assertion error raised inside a worker.
            for future in concurrent.futures.as_completed(futures):
                future.result()


class TestGetInvalid(object):
    """
    Test get entities with invalid params
    """
    @pytest.fixture(
        scope="function",
        params=gen_invalid_strs()
    )
    def get_collection_name(self, request):
        """Yield each value from gen_invalid_strs() for use as a collection name."""
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_invalid_strs()
    )
    def get_field_name(self, request):
        """Yield each value from gen_invalid_strs() for use as a field name."""
        yield request.param

    @pytest.fixture(
        scope="function",
        params=gen_invalid_ints()
    )
    def get_entity_id(self, request):
        """Yield each value from gen_invalid_ints() for use as an entity id."""
        yield request.param

    @pytest.mark.level(2)
    def test_insert_ids_invalid(self, connect, collection, get_entity_id):
        '''
        target: test get_entity_by_id, with ids which are not int64
        method: call get_entity_by_id with nb copies of an invalid id
        expected: raise an exception
        '''
        ids = [get_entity_id] * nb
        with pytest.raises(Exception):
            connect.get_entity_by_id(collection, ids)

    @pytest.mark.level(2)
    def test_insert_parts_ids_invalid(self, connect, collection, get_entity_id):
        '''
        target: test get_entity_by_id, with one id which is not int64
        method: call get_entity_by_id with valid ids whose last element is invalid
        expected: raise an exception
        '''
        ids = list(range(nb))
        ids[-1] = get_entity_id
        with pytest.raises(Exception):
            connect.get_entity_by_id(collection, ids)

    @pytest.mark.level(2)
    def test_get_entities_with_invalid_collection_name(self, connect, get_collection_name):
        '''
        target: test get_entity_by_id, with an invalid collection name
        method: call get_entity_by_id on an invalid collection name
        expected: raise an exception
        '''
        collection_name = get_collection_name
        ids = [1]
        with pytest.raises(Exception):
            connect.get_entity_by_id(collection_name, ids)

    @pytest.mark.level(2)
    def test_get_entities_with_invalid_field_name(self, connect, collection, get_field_name):
        '''
        target: test get_entity_by_id, with an invalid field name
        method: call get_entity_by_id with fields=[<invalid name>]
        expected: raise an exception
        '''
        # Named distinctly so it does not shadow the module-level `field_name`.
        invalid_field = get_field_name
        ids = [1]
        fields = [invalid_field]
        with pytest.raises(Exception):
            connect.get_entity_by_id(collection, ids, fields=fields)