Passed
Push — master ( fd4969...54df52 )
by
unknown
01:50
created

test_collection_logic.create_index()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 3
nop 2
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
import pdb
2
import pytest
3
import logging
4
import itertools
5
from time import sleep
6
from multiprocessing import Process
7
from milvus import IndexType, MetricType
8
from utils import *
9
10
# --- shared configuration -------------------------------------------------
dim = 128                           # length of every generated vector
default_segment_size = 1024
segment_size = 10
drop_collection_interval_time = 3   # presumably seconds -- TODO confirm
collection_id = "logic"

# --- data built once at import time -------------------------------------
# gen_vectors / gen_default_fields presumably come from the ``utils`` star
# import; their exact output format is not visible here.
vectors = gen_vectors(100, dim)
default_fields = gen_default_fields()
17
18
19
def create_collection(connect, **params):
    """Create ``params["collection_name"]`` using the module's ``default_fields``.

    The client's return value is passed back (instead of being discarded) so
    that callers dispatching through ``func_map`` -- which call ``.OK()`` on
    the result -- receive it like every other operation helper.

    NOTE(review): if the client version in use reports failure by raising
    rather than by returning a status, this may still return ``None`` --
    verify against the installed pymilvus.
    """
    return connect.create_collection(params["collection_name"], default_fields)
21
22
def search_collection(connect, **params):
    """Search ``params["collection_name"]`` and return only the status.

    Reads ``top_k``, ``query_vectors`` and ``nprobe`` from ``params``; the
    second element of the client's ``(status, result)`` pair is discarded.
    """
    name = params["collection_name"]
    top_k = params["top_k"]
    queries = params["query_vectors"]
    search_params = {"nprobe": params["nprobe"]}
    status, _hits = connect.search(name, top_k, queries, params=search_params)
    return status
29
30
def load_collection(connect, **params):
    """Load ``params["collection_name"]`` and return the client's result.

    The result is returned (not discarded) because callers dispatching
    through ``func_map`` call ``.OK()`` on it, as they do for every other
    operation helper.

    NOTE(review): if the client version in use signals failure by raising
    instead of returning a status, this may return ``None`` -- verify.
    """
    return connect.load_collection(params["collection_name"])
32
33
def has(connect, **params):
    """Ask whether ``params["collection_name"]`` exists; return only the status."""
    name = params["collection_name"]
    status, _exists = connect.has_collection(name)
    return status
36
37
def show(connect, **params):
    """List collections and return only the status; ``params`` is ignored."""
    status, _names = connect.list_collections()
    return status
40
41
def delete(connect, **params):
    """Drop ``params["collection_name"]`` and return the client's status as-is."""
    name = params["collection_name"]
    return connect.drop_collection(name)
44
45
def describe(connect, **params):
    """Fetch info for ``params["collection_name"]``; return only the status."""
    name = params["collection_name"]
    status, _info = connect.get_collection_info(name)
    return status
48
49
def rowcount(connect, **params):
    """Count entities in ``params["collection_name"]``; return only the status."""
    name = params["collection_name"]
    status, _count = connect.count_entities(name)
    return status
52
53
def create_index(connect, **params):
    """Build an index on ``params["collection_name"]``; return the client's status.

    ``index_type`` and ``index_param`` are forwarded positionally, unchanged.
    """
    name = params["collection_name"]
    index_type = params["index_type"]
    index_param = params["index_param"]
    return connect.create_index(name, index_type, index_param)
56
57
# Operation registry: key -> helper, each called as ``helper(connect, **params)``.
# Key 10 creates the collection and key 30 drops it; ``is_right`` treats every
# key above 10 as needing a created, not-yet-dropped collection.
# Insertion order matters: ``gen_sequence`` permutes the keys in this order.
func_map = {
    # 0: has,  -- currently disabled
    1: show,
    10: create_collection,
    11: describe,
    12: rowcount,
    13: search_collection,
    14: load_collection,
    15: create_index,
    30: delete,
}
68
69
def gen_sequence():
    """Lazily yield every ordering of the keys registered in ``func_map``.

    Each yielded value is a tuple containing every key exactly once, in the
    order produced by ``itertools.permutations``; ``func_map`` is read only
    when iteration starts.
    """
    yield from itertools.permutations(func_map.keys())
74
75
76
class TestCollectionLogic(object):
    """Run orderings of collection operations and check their outcomes.

    Each ordering produced by ``gen_sequence()`` is classified by
    ``is_right``: orderings expected to be valid must succeed at every step,
    the others must fail at some step.
    """

    @pytest.mark.parametrize("logic_seq", gen_sequence())
    @pytest.mark.level(2)
    def _test_logic(self, connect, logic_seq, args):
        """Execute ``logic_seq`` against ``connect`` and always clean up afterwards.

        NOTE(review): the leading underscore does not match pytest's default
        ``python_functions = test`` prefix, so this is presumably disabled on
        purpose -- confirm against the project's pytest config before
        re-enabling.
        """
        if args["handler"] == "HTTP":
            pytest.skip("Skip in http mode")
        try:
            if self.is_right(logic_seq):
                self.execute(logic_seq, connect)
            else:
                self.execute_with_error(logic_seq, connect)
        finally:
            # Drop collections even when a step's assertion fails, so one
            # failing ordering does not leave state behind for the next one.
            self.tear_down(connect)

    def is_right(self, seq):
        """Return True if executing ``seq`` in order is expected to succeed.

        Key 10 creates the collection and key 30 drops it. Every key greater
        than 10 (30 included) requires the collection to be created and not
        yet dropped; keys below 10 are allowed anywhere.

        No "already sorted" shortcut is taken: a sorted sequence such as
        ``[11, 12]`` still never creates the collection, so only the scan
        below decides.
        """
        created = False
        dropped = False
        for key in seq:
            if key > 10 and (not created or dropped):
                return False
            if key == 10:
                created = True
            elif key == 30:
                dropped = True
        return True

    def execute(self, logic_seq, connect):
        """Run every step of ``logic_seq`` and assert that each reports OK."""
        basic_params = self.gen_params()
        for key in logic_seq:
            f = func_map[key]
            status = f(connect, **basic_params)
            assert status.OK(), "step %r (%s) failed" % (key, f.__name__)

    def execute_with_error(self, logic_seq, connect):
        """Run ``logic_seq`` until a step fails; assert that some step failed."""
        basic_params = self.gen_params()
        # any() short-circuits, so no step after the first failure is executed.
        error_flag = any(
            not func_map[key](connect, **basic_params).OK() for key in logic_seq
        )
        assert error_flag, "expected ordering %r to fail" % (logic_seq,)

    def tear_down(self, connect):
        """Drop every collection named in the second element of ``list_collections()``."""
        names = connect.list_collections()[1]
        for name in names:
            connect.drop_collection(name)

    def gen_params(self):
        """Build the keyword arguments shared by every step of one ordering.

        A fresh collection name is derived from the module-level
        ``collection_id`` via ``gen_unique_str``. None of the operation
        helpers in this module read ``dimension`` or ``metric_type``.
        """
        collection_name = gen_unique_str(collection_id)
        query_vectors = gen_vectors(2, dim)
        return {
            'collection_name': collection_name,
            'dimension': dim,
            'metric_type': "L2",
            'nprobe': 1,
            'top_k': 1,
            'index_type': "IVF_SQ8",
            'index_param': {
                'nlist': 16384
            },
            'query_vectors': query_vectors,
        }
147