milvus_benchmark.client   F
last analyzed

Complexity

Total Complexity 80

Size/Duplication

Total Lines 371
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 297
dl 0
loc 371
rs 2
c 0
b 0
f 0
wmc 80

37 Methods

Rating   Name   Duplication   Size   Complexity  
A MilvusClient.list_partitions() 0 4 1
A MilvusClient.check_status() 0 7 2
A MilvusClient.check_result_ids() 0 6 3
A MilvusClient.delete_rand() 0 15 4
A MilvusClient.query_rand() 0 9 1
A MilvusClient.create_collection() 0 12 3
A MilvusClient.create_partition() 0 3 1
A MilvusClient.delete() 0 6 2
A MilvusClient.count() 0 9 3
A MilvusClient.create_index() 0 8 2
A MilvusClient.get_rand_entities() 0 5 1
C MilvusClient.__init__() 0 36 10
A MilvusClient.drop_index() 0 3 1
A MilvusClient.preload_collection() 0 5 1
A MilvusClient.__str__() 0 2 1
A MilvusClient.get_server_config() 0 2 1
A MilvusClient.exists_collection() 0 6 2
A MilvusClient.drop() 0 17 5
A MilvusClient.describe_index() 0 9 3
A MilvusClient.query() 0 6 2
A MilvusClient.insert_rand() 0 10 2
A MilvusClient.drop_partition() 0 3 1
A MilvusClient.describe() 0 3 1
A MilvusClient.get_mem_info() 0 7 1
A MilvusClient.get_server_version() 0 3 1
A MilvusClient.compact() 0 6 2
A MilvusClient.get_entities() 0 5 1
A MilvusClient.get_server_commit() 0 2 1
A MilvusClient.set_collection() 0 2 1
A MilvusClient.show_collections() 0 2 1
A MilvusClient.insert() 0 7 2
A MilvusClient.flush() 0 6 2
A MilvusClient.get_rand_ids_each_segment() 0 12 2
A MilvusClient.cmd() 0 5 1
A MilvusClient.get_server_mode() 0 2 1
A MilvusClient.get_rand_ids() 0 17 5
A MilvusClient.clean_db() 0 5 2

2 Functions

Rating   Name   Duplication   Size   Complexity  
A time_wrapper() 0 11 1
A metric_type_to_str() 0 5 3

How to fix   Complexity   

Complexity

Complex classes like milvus_benchmark.client often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
import sys
2
import pdb
3
import random
4
import logging
5
import json
6
import time, datetime
7
from multiprocessing import Process
8
from milvus import Milvus, IndexType, MetricType
9
import utils
10
11
# Module-level logger shared by every helper in this file.
logger = logging.getLogger("milvus_benchmark.client")

# Connection defaults used when the caller passes no host / port.
SERVER_HOST_DEFAULT = "127.0.0.1"
SERVER_PORT_DEFAULT = 19530

# Benchmark index names -> SDK IndexType members.
INDEX_MAP = {
    "flat": IndexType.FLAT,
    "ivf_flat": IndexType.IVFLAT,
    "ivf_sq8": IndexType.IVF_SQ8,
    "nsg": IndexType.RNSG,
    "ivf_sq8h": IndexType.IVF_SQ8H,
    "ivf_pq": IndexType.IVF_PQ,
    "hnsw": IndexType.HNSW,
    "annoy": IndexType.ANNOY,
}

# Benchmark metric names -> SDK MetricType members.
METRIC_MAP = {
    "l2": MetricType.L2,
    "ip": MetricType.IP,
    "jaccard": MetricType.JACCARD,
    "hamming": MetricType.HAMMING,
    "sub": MetricType.SUBSTRUCTURE,
    "super": MetricType.SUPERSTRUCTURE,
}

# Distance threshold used by MilvusClient.check_result_ids.
epsilon = 0.1
36
37
def time_wrapper(func):
    """
    This decorator logs the execution time of the decorated function.

    After the wrapped call returns, the elapsed wall-clock time (rounded to
    two decimals) is written to the module logger at INFO level and the
    wrapped function's return value is passed through unchanged. If the
    wrapped call raises, nothing is logged and the exception propagates.

    The wrapper keeps the metadata (``__name__``, ``__doc__``, ...) of the
    decorated function so decorated methods stay identifiable.
    """
    # Imported locally so the decorator does not depend on a module-level import.
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.time()
        result = func(*args, **kwargs)
        end = time.time()
        logger.info("Milvus {} run in {}s".format(func.__name__, round(end - start, 2)))
        return result
    return wrapper
48
49
50
def metric_type_to_str(metric_type):
    """
    Return the METRIC_MAP key whose value equals ``metric_type``.

    The first matching entry (in METRIC_MAP iteration order) wins.

    Raises:
        Exception: if no METRIC_MAP value compares equal to ``metric_type``.
    """
    # METRIC_MAP keys are strings, so None is a safe "not found" marker.
    found = next(
        (name for name, member in METRIC_MAP.items() if member == metric_type),
        None,
    )
    if found is None:
        raise Exception("metric_type: %s mapping not found" % metric_type)
    return found
55
56
57
class MilvusClient(object):
    """
    Convenience wrapper around the python-sdk ``Milvus`` client.

    The instance keeps a "current" collection name plus, when known, that
    collection's metric type (as a METRIC_MAP key) and dimension. Most
    methods operate on the current collection and raise ``Exception`` when
    the SDK reports a non-OK status (see ``check_status``).
    """

    def __init__(self, collection_name=None, host=None, port=None, timeout=60):
        """
        Milvus client wrapper for python-sdk.

        Default timeout set 60s

        Args:
            collection_name: name of the collection to operate on, or None.
            host: server host; falls back to SERVER_HOST_DEFAULT when falsy.
            port: server port; falls back to SERVER_PORT_DEFAULT when falsy.
            timeout: seconds to keep retrying the initial connection.

        Raises:
            Exception: "Server connect timeout" when no connection attempt
                succeeded before the deadline.
        """
        self._collection_name = collection_name
        # Always define these so later attribute access never fails.
        self._metric_type = None
        self._dimension = None
        if not host:
            host = SERVER_HOST_DEFAULT
        if not port:
            port = SERVER_PORT_DEFAULT
        logger.debug(host)
        logger.debug(port)
        start_time = time.time()
        # retry connect for remote server
        i = 0
        while time.time() < start_time + timeout:
            try:
                self._milvus = Milvus(host=host, port=port, try_connect=False, pre_ping=False)
                if self._milvus.server_status():
                    logger.debug("Try connect times: %d, %s" % (i, round(time.time() - start_time, 2)))
                    break
            except Exception:
                logger.debug("Milvus connect failed: %d times" % i)
                i = i + 1
        else:
            # Reached only when the loop ended without ``break``, i.e. the
            # deadline passed before any attempt succeeded.
            raise Exception("Server connect timeout")
        if self._collection_name and self.exists_collection():
            # Fetch the collection info once and reuse it for both fields.
            collection_info = self.describe()[1]
            self._metric_type = metric_type_to_str(collection_info.metric_type)
            self._dimension = collection_info.dimension

    def __str__(self):
        """Return a short label containing the current collection name."""
        return 'Milvus collection %s' % self._collection_name

    def set_collection(self, name):
        """Switch the current collection name (cached metric/dimension are not refreshed)."""
        self._collection_name = name

    def check_status(self, status):
        """
        Raise when ``status`` is not OK.

        Before raising, the collection name, the status message, the server
        status and the current row count are logged at ERROR level.

        Raises:
            Exception: "Status not ok" when ``status.OK()`` is falsy.
        """
        if not status.OK():
            logger.error(self._collection_name)
            logger.error(status.message)
            logger.error(self._milvus.server_status())
            logger.error(self.count())
            raise Exception("Status not ok")

    def check_result_ids(self, result):
        """
        Check that the first hit of every entry in ``result`` is closer than ``epsilon``.

        Raises:
            Exception: "Distance wrong" for the first entry whose top hit has
                a distance >= epsilon (its index and distance are logged).
        """
        for index, item in enumerate(result):
            if item[0].distance >= epsilon:
                logger.error(index)
                logger.error(item[0].distance)
                raise Exception("Distance wrong")

    def create_collection(self, collection_name, dimension, index_file_size, metric_type):
        """
        Create a collection on the server.

        If the client has no current collection yet, ``collection_name``
        becomes the current one. When the created collection is the current
        one, its metric type and dimension are cached on the instance so
        helpers such as ``insert_rand`` can use them.

        Args:
            collection_name: name of the collection to create.
            dimension: vector dimension.
            index_file_size: passed through to the SDK unchanged.
            metric_type: a METRIC_MAP key (e.g. "l2", "ip").

        Raises:
            Exception: for an unknown ``metric_type`` or a non-OK status.
        """
        if not self._collection_name:
            self._collection_name = collection_name
        if metric_type not in METRIC_MAP:
            raise Exception("Not supported metric_type: %s" % metric_type)
        create_param = {'collection_name': collection_name,
                        'dimension': dimension,
                        'index_file_size': index_file_size,
                        "metric_type": METRIC_MAP[metric_type]}
        status = self._milvus.create_collection(create_param)
        self.check_status(status)
        if collection_name == self._collection_name:
            # Keep cached schema info in sync with the collection just created.
            self._metric_type = metric_type
            self._dimension = dimension

    def create_partition(self, tag_name):
        """Create partition ``tag_name`` in the current collection; raise on non-OK status."""
        status = self._milvus.create_partition(self._collection_name, tag_name)
        self.check_status(status)

    def drop_partition(self, tag_name):
        """Drop partition ``tag_name`` from the current collection; raise on non-OK status."""
        status = self._milvus.drop_partition(self._collection_name, tag_name)
        self.check_status(status)

    def list_partitions(self):
        """Return the partitions of the current collection as reported by the SDK."""
        status, tags = self._milvus.list_partitions(self._collection_name)
        self.check_status(status)
        return tags

    @time_wrapper
    def insert(self, X, ids=None, collection_name=None):
        """
        Insert vectors ``X`` (optionally with explicit ``ids``).

        Uses the current collection when ``collection_name`` is None.

        Returns:
            The ``(status, result)`` pair returned by the SDK.

        Raises:
            Exception: when the returned status is not OK.
        """
        if collection_name is None:
            collection_name = self._collection_name
        status, result = self._milvus.insert(collection_name, X, ids)
        self.check_status(status)
        return status, result

    def insert_rand(self):
        """
        Insert between 1 and 100 random vectors and verify the row count grew accordingly.

        Relies on the cached dimension and metric type of the current
        collection.

        Raises:
            Exception: "Assert failed after inserting" when the row count
                after flushing does not equal the previous count plus the
                number of inserted vectors.
        """
        insert_xb = random.randint(1, 100)
        X = [[random.random() for _ in range(self._dimension)] for _ in range(insert_xb)]
        X = utils.normalize(self._metric_type, X)
        count_before = self.count()
        # insert() already validates the returned status.
        self.insert(X)
        self.flush()
        if count_before + insert_xb != self.count():
            raise Exception("Assert failed after inserting")

    def get_rand_ids(self, length):
        """
        Return up to ``length`` ids sampled from one randomly chosen segment.

        Only the first partition listed in the collection stats is
        considered. Segments are re-drawn until one with at least one id is
        found, so this loops for as long as every drawn segment is empty or
        its id listing fails. If the chosen segment holds ``length`` ids or
        fewer, all of its ids are returned.

        Raises:
            Exception: when the stats call is not OK or the first partition
                reports no segments.
        """
        while True:
            status, stats = self._milvus.get_collection_stats(self._collection_name)
            self.check_status(status)
            segments = stats["partitions"][0]["segments"]
            if not segments:
                # random.choice would fail with an opaque error on an empty/None value.
                raise Exception("No segments found in collection: %s" % self._collection_name)
            # random choice one segment
            segment = random.choice(segments)
            status, segment_ids = self._milvus.list_id_in_segment(self._collection_name, segment["name"])
            if not status.OK():
                logger.error(status.message)
                continue
            if len(segment_ids):
                break
        if length >= len(segment_ids):
            logger.debug("Reset length: %d" % len(segment_ids))
            return segment_ids
        return random.sample(segment_ids, length)

    def get_rand_ids_each_segment(self, length):
        """
        Collect the first ``length`` ids of every segment in the first partition.

        Note that, despite the name, ids are taken from the head of each
        segment's id list rather than sampled.

        Returns:
            ``(segments_num, ids)``: the number of segments and the
            concatenated id list.
        """
        res = []
        status, stats = self._milvus.get_collection_stats(self._collection_name)
        self.check_status(status)
        segments = stats["partitions"][0]["segments"]
        segments_num = len(segments)
        # take ids from each segment
        for segment in segments:
            status, segment_ids = self._milvus.list_id_in_segment(self._collection_name, segment["name"])
            self.check_status(status)
            res.extend(segment_ids[:length])
        return segments_num, res

    def get_rand_entities(self, length):
        """Fetch entities for ids chosen by ``get_rand_ids``; return ``(ids, entities)``."""
        ids = self.get_rand_ids(length)
        status, get_res = self._milvus.get_entity_by_id(self._collection_name, ids)
        self.check_status(status)
        return ids, get_res

    @time_wrapper
    def get_entities(self, get_ids):
        """Return the entities stored under ``get_ids`` in the current collection."""
        status, get_res = self._milvus.get_entity_by_id(self._collection_name, get_ids)
        self.check_status(status)
        return get_res

    @time_wrapper
    def delete(self, ids, collection_name=None):
        """Delete entities by id (current collection when ``collection_name`` is None)."""
        if collection_name is None:
            collection_name = self._collection_name
        status = self._milvus.delete_entity_by_id(collection_name, ids)
        self.check_status(status)

    def delete_rand(self):
        """
        Delete up to 100 randomly chosen entities and verify they are gone.

        Raises:
            Exception: "Assert failed after delete" when any deleted id still
                resolves to an entity, or when the row count did not drop by
                the number of deleted ids.
        """
        delete_id_length = random.randint(1, 100)
        count_before = self.count()
        logger.info("%s: length to delete: %d" % (self._collection_name, delete_id_length))
        delete_ids = self.get_rand_ids(delete_id_length)
        self.delete(delete_ids)
        self.flush()
        logger.info("%s: count after delete: %d" % (self._collection_name, self.count()))
        status, get_res = self._milvus.get_entity_by_id(self._collection_name, delete_ids)
        self.check_status(status)
        for item in get_res:
            if item:
                raise Exception("Assert failed after delete")
        if count_before - len(delete_ids) != self.count():
            raise Exception("Assert failed after delete")

    @time_wrapper
    def flush(self, collection_name=None):
        """Flush one collection (current collection when ``collection_name`` is None)."""
        if collection_name is None:
            collection_name = self._collection_name
        status = self._milvus.flush([collection_name])
        self.check_status(status)

    @time_wrapper
    def compact(self, collection_name=None):
        """Compact one collection (current collection when ``collection_name`` is None)."""
        if collection_name is None:
            collection_name = self._collection_name
        status = self._milvus.compact(collection_name)
        self.check_status(status)

    @time_wrapper
    def create_index(self, index_type, index_param=None):
        """
        Build an index on the current collection.

        Args:
            index_type: an INDEX_MAP key (e.g. "ivf_flat"); an unknown key
                raises KeyError.
            index_param: passed through to the SDK unchanged.
        """
        index_type = INDEX_MAP[index_type]
        logger.info("Building index start, collection_name: %s, index_type: %s" % (self._collection_name, index_type))
        if index_param:
            logger.info(index_param)
        status = self._milvus.create_index(self._collection_name, index_type, index_param)
        self.check_status(status)

    def describe_index(self):
        """
        Return the index description of the current collection.

        Returns:
            ``{"index_type": <INDEX_MAP key or None>, "index_param": <params>}``;
            ``index_type`` is None when the SDK's index type has no entry in
            INDEX_MAP.
        """
        status, result = self._milvus.get_index_info(self._collection_name)
        self.check_status(status)
        index_type = None
        for k, v in INDEX_MAP.items():
            if result._index_type == v:
                index_type = k
                break
        return {"index_type": index_type, "index_param": result._params}

    def drop_index(self):
        """Drop the index of the current collection and return the SDK status (not checked here)."""
        logger.info("Drop index: %s" % self._collection_name)
        return self._milvus.drop_index(self._collection_name)

    def query(self, X, top_k, search_param=None, collection_name=None):
        """Search ``top_k`` neighbours for the query vectors ``X`` and return the SDK result."""
        if collection_name is None:
            collection_name = self._collection_name
        status, result = self._milvus.search(collection_name, top_k, query_records=X, params=search_param)
        self.check_status(status)
        return result

    def query_rand(self):
        """
        Run one search with random ``nq``, ``top_k`` and ``nprobe`` (each 1..100).

        Query vectors are existing entities fetched through
        ``get_rand_entities``. Only the returned status is validated.
        """
        top_k = random.randint(1, 100)
        nq = random.randint(1, 100)
        nprobe = random.randint(1, 100)
        search_param = {"nprobe": nprobe}
        _, X = self.get_rand_entities(nq)
        logger.info("%s, Search nq: %d, top_k: %d, nprobe: %d" % (self._collection_name, nq, top_k, nprobe))
        status, _ = self._milvus.search(self._collection_name, top_k, query_records=X, params=search_param)
        self.check_status(status)
        # for i, item in enumerate(search_res):
        #     if item[0].id != ids[i]:
        #         logger.warning("The index of search result: %d" % i)
        #         raise Exception("Query failed")

    # @time_wrapper
    # def query_ids(self, top_k, ids, search_param=None):
    #     status, result = self._milvus.search_by_id(self._collection_name, ids, top_k, params=search_param)
    #     self.check_result_ids(result)
    #     return result

    def count(self, name=None):
        """
        Return the row count of a collection (current collection when ``name`` is None).

        The status returned by the SDK is intentionally not checked here
        (``check_status`` itself calls this method); a falsy count is
        reported as 0.
        """
        if name is None:
            name = self._collection_name
        # Query the server once and reuse the reply for logging and the result.
        count_result = self._milvus.count_entities(name)
        logger.debug(count_result)
        row_count = count_result[1]
        if not row_count:
            row_count = 0
        logger.debug("Row count: %d in collection: <%s>" % (row_count, name))
        return row_count

    def drop(self, timeout=120, name=None):
        """
        Drop a collection and wait until it reports no rows.

        After the drop request, the row count is polled once per second for
        at most ``timeout`` polls; if rows are still reported afterwards an
        error is logged (no exception is raised for the timeout).

        Args:
            timeout: maximum number of one-second polls (converted with int()).
            name: collection to drop; the current collection when None.
        """
        timeout = int(timeout)
        if name is None:
            name = self._collection_name
        logger.info("Start delete collection: %s" % name)
        status = self._milvus.drop_collection(name)
        self.check_status(status)
        i = 0
        while i < timeout:
            if self.count(name=name):
                time.sleep(1)
                i = i + 1
                continue
            else:
                break
        if i >= timeout:
            logger.error("Delete collection timeout")

    def describe(self):
        """Return the SDK's collection-info reply for the current collection, unchecked."""
        # logger.info(self._milvus.get_collection_info(self._collection_name))
        return self._milvus.get_collection_info(self._collection_name)

    def show_collections(self):
        """Return the SDK's list-collections reply, unchecked."""
        return self._milvus.list_collections()

    def exists_collection(self, collection_name=None):
        """Return the SDK's has-collection answer (current collection when the name is None)."""
        if collection_name is None:
            collection_name = self._collection_name
        _, res = self._milvus.has_collection(collection_name)
        # self.check_status(status)
        return res

    def clean_db(self):
        """Drop every collection returned by ``show_collections``."""
        collection_names = self.show_collections()[1]
        for name in collection_names:
            logger.debug(name)
            self.drop(name=name)

    @time_wrapper
    def preload_collection(self):
        """Load the current collection (SDK timeout 3000) and return the checked status."""
        status = self._milvus.load_collection(self._collection_name, timeout=3000)
        self.check_status(status)
        return status

    def get_server_version(self):
        """Return the server version value reported by the SDK (status is not checked)."""
        _, res = self._milvus.server_version()
        return res

    def get_server_mode(self):
        """Return the result of the server command "mode"."""
        return self.cmd("mode")

    def get_server_commit(self):
        """Return the result of the server command "build_commit_id"."""
        return self.cmd("build_commit_id")

    def get_server_config(self):
        """Return the JSON-decoded result of the server command "get_config *"."""
        return json.loads(self.cmd("get_config *"))

    def get_mem_info(self):
        """
        Return the server memory usage.

        Returns:
            ``{"memory_used": <value>}`` where the value is the server's
            ``memory_used`` figure divided by 1024**3 and rounded to two
            decimals.
        """
        result = json.loads(self.cmd("get_system_info"))
        result_human = {
            # unit: Gb
            "memory_used": round(int(result["memory_used"]) / (1024*1024*1024), 2)
        }
        return result_human

    def cmd(self, command):
        """
        Send a raw command string to the server and return its result.

        The command and its result are logged before the status is checked.

        Raises:
            Exception: when the returned status is not OK.
        """
        status, res = self._milvus._cmd(command)
        logger.info("Server command: %s, result: %s" % (command, res))
        self.check_status(status)
        return res
371