1
|
|
|
import sys |
2
|
|
|
import pdb |
3
|
|
|
import random |
4
|
|
|
import logging |
5
|
|
|
import json |
6
|
|
|
import time, datetime |
7
|
|
|
from multiprocessing import Process |
8
|
|
|
from milvus import Milvus, IndexType, MetricType |
9
|
|
|
import utils |
10
|
|
|
|
11
|
|
|
# Module-level logger used by every helper in this file.
logger = logging.getLogger("milvus_benchmark.client")

# Connection defaults applied when MilvusClient is created without host/port.
SERVER_HOST_DEFAULT = "127.0.0.1"
SERVER_PORT_DEFAULT = 19530

# Benchmark index-type names -> SDK IndexType members.
INDEX_MAP = {
    "flat": IndexType.FLAT,
    "ivf_flat": IndexType.IVFLAT,
    "ivf_sq8": IndexType.IVF_SQ8,
    "nsg": IndexType.RNSG,
    "ivf_sq8h": IndexType.IVF_SQ8H,
    "ivf_pq": IndexType.IVF_PQ,
    "hnsw": IndexType.HNSW,
    "annoy": IndexType.ANNOY,
}

# Benchmark metric names -> SDK MetricType members.
METRIC_MAP = {
    "l2": MetricType.L2,
    "ip": MetricType.IP,
    "jaccard": MetricType.JACCARD,
    "hamming": MetricType.HAMMING,
    "sub": MetricType.SUBSTRUCTURE,
    "super": MetricType.SUPERSTRUCTURE,
}

# Distance threshold compared against in MilvusClient.check_result_ids.
epsilon = 0.1
36
|
|
|
|
37
|
|
|
def time_wrapper(func):
    """
    Decorator that logs the execution time of the decorated function.

    The elapsed time is written to the module logger at INFO level after the
    call returns; nothing is logged when the call raises.

    :param func: callable to wrap
    :return: wrapper that forwards all arguments and returns func's result
    """
    # Imported locally so this block does not depend on a file-level import.
    import functools

    # functools.wraps keeps func.__name__ / __doc__ on the wrapper, so decorated
    # methods stay identifiable in tracebacks and introspection.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic, so the measurement is not affected by
        # system clock adjustments.
        start = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start
        logger.info("Milvus %s run in %ss", func.__name__, round(elapsed, 2))
        return result
    return wrapper
48
|
|
|
|
49
|
|
|
|
50
|
|
|
def metric_type_to_str(metric_type):
    """
    Reverse lookup in METRIC_MAP: return the first key whose value equals
    ``metric_type``.

    :raises Exception: when no METRIC_MAP value equals ``metric_type``
    """
    not_found = object()
    name = next(
        (key for key, value in METRIC_MAP.items() if value == metric_type),
        not_found,
    )
    if name is not_found:
        raise Exception("metric_type: %s mapping not found" % metric_type)
    return name
55
|
|
|
|
56
|
|
|
|
57
|
|
|
class MilvusClient(object):
    """
    Wrapper around the Milvus python-sdk client used by the benchmark.

    Methods default to the collection name stored on the instance. Server
    responses are validated with ``check_status``, which raises
    ``Exception("Status not ok")`` for a non-OK status.
    """

    def __init__(self, collection_name=None, host=None, port=None, timeout=60):
        """
        Milvus client wrapper for python-sdk.

        Default timeout set 60s

        :param collection_name: default collection for the wrapper methods
        :param host: server host; SERVER_HOST_DEFAULT is used when falsy
        :param port: server port; SERVER_PORT_DEFAULT is used when falsy
        :param timeout: seconds during which the connection is retried
        :raises Exception: "Server connect timeout" if no attempt succeeded
        """
        self._collection_name = collection_name
        # Always defined so later methods never hit AttributeError when the
        # collection did not exist at construction time.
        self._metric_type = None
        self._dimension = None

        start_time = time.time()
        host = host or SERVER_HOST_DEFAULT
        port = port or SERVER_PORT_DEFAULT
        logger.debug(host)
        logger.debug(port)

        # retry connect for remote server
        deadline = start_time + timeout
        attempts = 0
        connected = False
        while time.time() < deadline:
            try:
                self._milvus = Milvus(host=host, port=port, try_connect=False, pre_ping=False)
                if self._milvus.server_status():
                    logger.debug("Try connect times: %d, %s", attempts, round(time.time() - start_time, 2))
                    connected = True
                    break
            except Exception:
                logger.debug("Milvus connect failed: %d times", attempts)
                attempts += 1

        # Decide on the outcome of the attempts rather than re-reading the
        # clock: a connection that succeeded late is still a success.
        if not connected:
            raise Exception("Server connect timeout")

        if self._collection_name and self.exists_collection():
            # describe() returns (status, info); fetch it once for both fields.
            collection_info = self.describe()[1]
            self._metric_type = metric_type_to_str(collection_info.metric_type)
            self._dimension = collection_info.dimension

    def __str__(self):
        return 'Milvus collection %s' % self._collection_name

    def set_collection(self, name):
        """Switch the default collection name used by the wrapper methods."""
        # NOTE(review): the cached metric type / dimension are not refreshed
        # here - confirm callers do not rely on them after switching.
        self._collection_name = name

    def check_status(self, status):
        """
        Raise when ``status`` is not OK.

        Before raising, logs the collection name, the status message, the
        server status and the current entity count at ERROR level.

        :raises Exception: "Status not ok"
        """
        if status.OK():
            return
        logger.error(self._collection_name)
        logger.error(status.message)
        logger.error(self._milvus.server_status())
        logger.error(self.count())
        raise Exception("Status not ok")

    def check_result_ids(self, result):
        """
        Verify that the first hit of every entry in ``result`` has a distance
        below ``epsilon``.

        :raises Exception: "Distance wrong" on the first entry at or above it
        """
        for index, item in enumerate(result):
            top_distance = item[0].distance
            if top_distance >= epsilon:
                logger.error(index)
                logger.error(top_distance)
                raise Exception("Distance wrong")

    def create_collection(self, collection_name, dimension, index_file_size, metric_type):
        """
        Create a collection on the server.

        Adopts ``collection_name`` as the default collection when the client
        has none yet.

        :param metric_type: a key of METRIC_MAP (e.g. "l2", "ip")
        :raises Exception: for an unsupported metric_type or a non-OK status
        """
        if not self._collection_name:
            self._collection_name = collection_name
        if metric_type not in METRIC_MAP:
            raise Exception("Not supported metric_type: %s" % metric_type)
        metric_name = metric_type
        metric_type = METRIC_MAP[metric_type]
        create_param = {'collection_name': collection_name,
                        'dimension': dimension,
                        'index_file_size': index_file_size,
                        "metric_type": metric_type}
        status = self._milvus.create_collection(create_param)
        self.check_status(status)
        if collection_name == self._collection_name:
            # Keep the cached values in step with the collection just created,
            # so insert_rand works without constructing a new client.
            self._metric_type = metric_name
            self._dimension = dimension

    def create_partition(self, tag_name):
        """Create partition ``tag_name`` in the default collection."""
        status = self._milvus.create_partition(self._collection_name, tag_name)
        self.check_status(status)

    def drop_partition(self, tag_name):
        """Drop partition ``tag_name`` from the default collection."""
        status = self._milvus.drop_partition(self._collection_name, tag_name)
        self.check_status(status)

    def list_partitions(self):
        """Return the partitions of the default collection."""
        status, tags = self._milvus.list_partitions(self._collection_name)
        self.check_status(status)
        return tags

    @time_wrapper
    def insert(self, X, ids=None, collection_name=None):
        """
        Insert vectors ``X`` (optionally with explicit ``ids``).

        :param collection_name: target collection; defaults to the client's
        :return: ``(status, result)`` as returned by the SDK
        """
        if collection_name is None:
            collection_name = self._collection_name
        status, result = self._milvus.insert(collection_name, X, ids)
        self.check_status(status)
        return status, result

    def insert_rand(self):
        """
        Insert 1-100 random vectors, flush, and verify the entity count grew
        by exactly that number.

        :raises Exception: when the dimension is unknown or the count check fails
        """
        if self._dimension is None:
            raise Exception("Collection dimension unknown: %s" % self._collection_name)
        insert_xb = random.randint(1, 100)
        X = [[random.random() for _ in range(self._dimension)] for _ in range(insert_xb)]
        X = utils.normalize(self._metric_type, X)
        count_before = self.count()
        status, _ = self.insert(X)
        self.check_status(status)
        self.flush()
        if count_before + insert_xb != self.count():
            raise Exception("Assert failed after inserting")

    def get_rand_ids(self, length):
        """
        Return ids taken from one randomly chosen segment of the first
        partition reported by the collection stats.

        Keeps choosing segments until one yields a non-empty id list. When the
        segment holds ``length`` ids or fewer, all of its ids are returned;
        otherwise ``length`` ids are sampled at random.
        """
        while True:
            status, stats = self._milvus.get_collection_stats(self._collection_name)
            self.check_status(status)
            segments = stats["partitions"][0]["segments"]
            # random choice one segment
            segment = random.choice(segments)
            status, segment_ids = self._milvus.list_id_in_segment(self._collection_name, segment["name"])
            if not status.OK():
                logger.error(status.message)
                continue
            if len(segment_ids):
                break
        if length >= len(segment_ids):
            logger.debug("Reset length: %d", len(segment_ids))
            return segment_ids
        return random.sample(segment_ids, length)

    def get_rand_ids_each_segment(self, length):
        """
        Collect the first ``length`` ids of every segment in the first
        partition reported by the collection stats.

        :return: ``(segments_num, ids)``
        """
        res = []
        status, stats = self._milvus.get_collection_stats(self._collection_name)
        self.check_status(status)
        segments = stats["partitions"][0]["segments"]
        segments_num = len(segments)
        # take a slice from each segment
        for segment in segments:
            status, segment_ids = self._milvus.list_id_in_segment(self._collection_name, segment["name"])
            self.check_status(status)
            res.extend(segment_ids[:length])
        return segments_num, res

    def get_rand_entities(self, length):
        """
        Fetch the entities for ids picked by ``get_rand_ids(length)``.

        :return: ``(ids, entities)``
        """
        ids = self.get_rand_ids(length)
        status, get_res = self._milvus.get_entity_by_id(self._collection_name, ids)
        self.check_status(status)
        return ids, get_res

    @time_wrapper
    def get_entities(self, get_ids):
        """Return the entities stored under ``get_ids`` in the default collection."""
        status, get_res = self._milvus.get_entity_by_id(self._collection_name, get_ids)
        self.check_status(status)
        return get_res

    @time_wrapper
    def delete(self, ids, collection_name=None):
        """Delete the entities with the given ``ids``."""
        if collection_name is None:
            collection_name = self._collection_name
        status = self._milvus.delete_entity_by_id(collection_name, ids)
        self.check_status(status)

    def delete_rand(self):
        """
        Delete up to 100 randomly chosen entities, flush, and verify that they
        can no longer be fetched and that the count dropped accordingly.

        :raises Exception: "Assert failed after delete" when a check fails
        """
        delete_id_length = random.randint(1, 100)
        count_before = self.count()
        logger.info("%s: length to delete: %d", self._collection_name, delete_id_length)
        delete_ids = self.get_rand_ids(delete_id_length)
        self.delete(delete_ids)
        self.flush()
        logger.info("%s: count after delete: %d", self._collection_name, self.count())
        status, get_res = self._milvus.get_entity_by_id(self._collection_name, delete_ids)
        self.check_status(status)
        if any(get_res):
            raise Exception("Assert failed after delete")
        if count_before - len(delete_ids) != self.count():
            raise Exception("Assert failed after delete")

    @time_wrapper
    def flush(self, collection_name=None):
        """Flush the given collection (default: the client's collection)."""
        if collection_name is None:
            collection_name = self._collection_name
        status = self._milvus.flush([collection_name])
        self.check_status(status)

    @time_wrapper
    def compact(self, collection_name=None):
        """Compact the given collection (default: the client's collection)."""
        if collection_name is None:
            collection_name = self._collection_name
        status = self._milvus.compact(collection_name)
        self.check_status(status)

    @time_wrapper
    def create_index(self, index_type, index_param=None):
        """
        Build an index on the default collection.

        :param index_type: a key of INDEX_MAP (KeyError for unknown names)
        :param index_param: parameters forwarded to the SDK
        """
        index_type = INDEX_MAP[index_type]
        logger.info("Building index start, collection_name: %s, index_type: %s", self._collection_name, index_type)
        if index_param:
            logger.info(index_param)
        status = self._milvus.create_index(self._collection_name, index_type, index_param)
        self.check_status(status)

    def describe_index(self):
        """
        Return ``{"index_type": <INDEX_MAP key or None>, "index_param": ...}``
        for the default collection.
        """
        status, result = self._milvus.get_index_info(self._collection_name)
        self.check_status(status)
        index_type = None
        for k, v in INDEX_MAP.items():
            if result._index_type == v:
                index_type = k
                break
        return {"index_type": index_type, "index_param": result._params}

    def drop_index(self):
        """Drop the index of the default collection; returns the SDK result unchecked."""
        logger.info("Drop index: %s", self._collection_name)
        return self._milvus.drop_index(self._collection_name)

    def query(self, X, top_k, search_param=None, collection_name=None):
        """
        Search ``top_k`` neighbours for the query vectors ``X``.

        :return: the search result returned by the SDK
        """
        if collection_name is None:
            collection_name = self._collection_name
        status, result = self._milvus.search(collection_name, top_k, query_records=X, params=search_param)
        self.check_status(status)
        return result

    def query_rand(self):
        """
        Run a search with random ``top_k``, ``nq`` and ``nprobe`` (each 1-100),
        using randomly fetched entities as query vectors. Only the status is
        checked; the search result is discarded.
        """
        top_k = random.randint(1, 100)
        nq = random.randint(1, 100)
        nprobe = random.randint(1, 100)
        search_param = {"nprobe": nprobe}
        _, X = self.get_rand_entities(nq)
        logger.info("%s, Search nq: %d, top_k: %d, nprobe: %d", self._collection_name, nq, top_k, nprobe)
        status, _ = self._milvus.search(self._collection_name, top_k, query_records=X, params=search_param)
        self.check_status(status)
        # for i, item in enumerate(search_res):
        #     if item[0].id != ids[i]:
        #         logger.warning("The index of search result: %d" % i)
        #         raise Exception("Query failed")

    # @time_wrapper
    # def query_ids(self, top_k, ids, search_param=None):
    #     status, result = self._milvus.search_by_id(self._collection_name, ids, top_k, params=search_param)
    #     self.check_result_ids(result)
    #     return result

    def count(self, name=None):
        """
        Return the number of entities in ``name`` (default: the client's
        collection); a falsy count from the server is reported as 0.
        """
        if name is None:
            name = self._collection_name
        # One round trip serves both the debug log and the returned value.
        response = self._milvus.count_entities(name)
        logger.debug(response)
        row_count = response[1]
        if not row_count:
            row_count = 0
        logger.debug("Row count: %d in collection: <%s>", row_count, name)
        return row_count

    def drop(self, timeout=120, name=None):
        """
        Drop a collection and poll once per second until its entity count is
        falsy. Logs an error (does not raise) if ``timeout`` polls elapse.

        :param timeout: maximum number of one-second polls
        :param name: collection to drop; defaults to the client's collection
        """
        timeout = int(timeout)
        if name is None:
            name = self._collection_name
        logger.info("Start delete collection: %s", name)
        status = self._milvus.drop_collection(name)
        self.check_status(status)
        waited = 0
        while waited < timeout and self.count(name=name):
            time.sleep(1)
            waited += 1
        if waited >= timeout:
            logger.error("Delete collection timeout")

    def describe(self):
        """Return the SDK's collection info response for the default collection."""
        # logger.info(self._milvus.get_collection_info(self._collection_name))
        return self._milvus.get_collection_info(self._collection_name)

    def show_collections(self):
        """Return the SDK's list_collections response."""
        return self._milvus.list_collections()

    def exists_collection(self, collection_name=None):
        """Return the SDK's has_collection result; the status is not checked."""
        if collection_name is None:
            collection_name = self._collection_name
        _, res = self._milvus.has_collection(collection_name)
        # self.check_status(status)
        return res

    def clean_db(self):
        """Drop every collection reported by ``show_collections``."""
        collection_names = self.show_collections()[1]
        for name in collection_names:
            logger.debug(name)
            self.drop(name=name)

    @time_wrapper
    def preload_collection(self):
        """Load the default collection on the server (SDK timeout 3000) and return the status."""
        status = self._milvus.load_collection(self._collection_name, timeout=3000)
        self.check_status(status)
        return status

    def get_server_version(self):
        """Return the server version; the status is not checked."""
        _, res = self._milvus.server_version()
        return res

    def get_server_mode(self):
        """Return the result of the "mode" server command."""
        return self.cmd("mode")

    def get_server_commit(self):
        """Return the result of the "build_commit_id" server command."""
        return self.cmd("build_commit_id")

    def get_server_config(self):
        """Return the JSON-decoded result of the "get_config *" server command."""
        return json.loads(self.cmd("get_config *"))

    def get_mem_info(self):
        """
        Return ``{"memory_used": <value>}`` where the value is the server's
        "memory_used" divided by 1024**3 and rounded to two decimals.
        """
        result = json.loads(self.cmd("get_system_info"))
        result_human = {
            # unit: Gb
            "memory_used": round(int(result["memory_used"]) / (1024*1024*1024), 2)
        }
        return result_human

    def cmd(self, command):
        """Run a raw server command, log its result and return it."""
        status, res = self._milvus._cmd(command)
        logger.info("Server command: %s, result: %s", command, res)
        self.check_status(status)
        return res
371
|
|
|
|