import os
import logging
import pdb
import time
import random
from multiprocessing import Process
from itertools import product
import numpy as np
import sklearn.preprocessing
from client import MilvusClient
import utils
import parser

logger = logging.getLogger("milvus_benchmark.runner")

VECTORS_PER_FILE = 1000000
SIFT_VECTORS_PER_FILE = 100000
JACCARD_VECTORS_PER_FILE = 2000000

MAX_NQ = 10001
FILE_PREFIX = "binary_"

# FOLDER_NAME = 'ann_1000m/source_data'
SRC_BINARY_DATA_DIR = '/test/milvus/raw_data/random/'
SIFT_SRC_DATA_DIR = '/test/milvus/raw_data/sift1b/'
DEEP_SRC_DATA_DIR = '/test/milvus/raw_data/deep1b/'
JACCARD_SRC_DATA_DIR = '/test/milvus/raw_data/jaccard/'
HAMMING_SRC_DATA_DIR = '/test/milvus/raw_data/jaccard/'
STRUCTURE_SRC_DATA_DIR = '/test/milvus/raw_data/jaccard/'
SIFT_SRC_GROUNDTRUTH_DATA_DIR = SIFT_SRC_DATA_DIR + 'gnd'

WARM_TOP_K = 1
WARM_NQ = 1
DEFAULT_DIM = 512


GROUNDTRUTH_MAP = {
    "1000000": "idx_1M.ivecs",
    "2000000": "idx_2M.ivecs",
    "5000000": "idx_5M.ivecs",
    "10000000": "idx_10M.ivecs",
    "20000000": "idx_20M.ivecs",
    "50000000": "idx_50M.ivecs",
    "100000000": "idx_100M.ivecs",
    "200000000": "idx_200M.ivecs",
    "500000000": "idx_500M.ivecs",
    "1000000000": "idx_1000M.ivecs",
}


def gen_file_name(idx, dimension, data_type):
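    """Build the path of the source .npy data file for the given file index,
    vector dimension and data type."""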
    s = "%05d" % idx
    fname = FILE_PREFIX + str(dimension) + "d_" + s + ".npy"
    if data_type == "random":
        fname = SRC_BINARY_DATA_DIR + fname
    elif data_type == "sift":
        fname = SIFT_SRC_DATA_DIR + fname
    elif data_type == "deep":
        fname = DEEP_SRC_DATA_DIR + fname
    elif data_type == "jaccard":
        fname = JACCARD_SRC_DATA_DIR + fname
    elif data_type == "hamming":
        fname = HAMMING_SRC_DATA_DIR + fname
    elif data_type == "sub" or data_type == "super":
        fname = STRUCTURE_SRC_DATA_DIR + fname
    return fname


def get_vectors_from_binary(nq, dimension, data_type):
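    """Load the first `nq` query vectors of the given dimension and data type
    from the corresponding query .npy file."""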
    # use the first file, nq should be less than VECTORS_PER_FILE
    if nq > MAX_NQ:
        raise Exception("Over size nq")
    if data_type == "random":
        file_name = SRC_BINARY_DATA_DIR + 'query_%d.npy' % dimension
    elif data_type == "sift":
        file_name = SIFT_SRC_DATA_DIR + 'query.npy'
    elif data_type == "deep":
        file_name = DEEP_SRC_DATA_DIR + 'query.npy'
    elif data_type == "jaccard":
        file_name = JACCARD_SRC_DATA_DIR + 'query.npy'
    elif data_type == "hamming":
        file_name = HAMMING_SRC_DATA_DIR + 'query.npy'
    elif data_type == "sub" or data_type == "super":
        file_name = STRUCTURE_SRC_DATA_DIR + 'query.npy'
    data = np.load(file_name)
    vectors = data[0:nq].tolist()
    return vectors


class Runner(object):
    def __init__(self):
        """Run each test defined in the suites."""
        pass

    def normalize(self, metric_type, X):
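        """Prepare vectors for the given metric type: L2-normalize and cast to
        float32 for "ip", cast to float32 for "l2", and pack bits into bytes
        for the binary metrics ("jaccard", "hamming", "sub", "super")."""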
        if metric_type == "ip":
            logger.info("Set normalize for metric_type: %s" % metric_type)
            X = sklearn.preprocessing.normalize(X, axis=1, norm='l2')
            X = X.astype(np.float32)
        elif metric_type == "l2":
            X = X.astype(np.float32)
        elif metric_type in ["jaccard", "hamming", "sub", "super"]:
            tmp = []
            for item in X:
                new_vector = bytes(np.packbits(item, axis=-1).tolist())
                tmp.append(new_vector)
            X = tmp
        return X

    def generate_combinations(self, args):
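        """Expand list values in `args` into the cross product of all combinations.

        For example, [[1, 2], 10] yields [[1, 10], [2, 10]] and
        {"a": [1, 2], "b": 3} yields [{"a": 1, "b": 3}, {"a": 2, "b": 3}].
        """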
        if isinstance(args, list):
            args = [el if isinstance(el, list) else [el] for el in args]
            return [list(x) for x in product(*args)]
        elif isinstance(args, dict):
            flat = []
            for k, v in args.items():
                if isinstance(v, list):
                    flat.append([(k, el) for el in v])
                else:
                    flat.append([(k, v)])
            return [dict(x) for x in product(*flat)]
        else:
            raise TypeError("No args handling exists for %s" % type(args).__name__)

    def do_insert(self, milvus, collection_name, data_type, dimension, size, ni):
        '''
        @params:
            milvus: server connection instance
            dimension: collection dimension
            # index_file_size: size that triggers file merge
            size: row count of vectors to be inserted
            ni: row count of vectors to be inserted per batch
            # store_id: whether to store the ids returned by add_vectors
        @return:
            total_time: total time of all insert operations
            qps: vectors added per second
            ni_time: average time of one insert operation
        '''
        bi_res = {}
        total_time = 0.0
        qps = 0.0
        ni_time = 0.0
        if data_type == "random":
            if dimension == 512:
                vectors_per_file = VECTORS_PER_FILE
            elif dimension == 4096:
                vectors_per_file = 100000
            elif dimension == 16384:
                vectors_per_file = 10000
            else:
                raise Exception("dimension: %d not supported" % dimension)
        elif data_type == "sift":
            vectors_per_file = SIFT_VECTORS_PER_FILE
        elif data_type in ["jaccard", "hamming", "sub", "super"]:
            vectors_per_file = JACCARD_VECTORS_PER_FILE
        else:
            raise Exception("data_type: %s not supported" % data_type)
        if size % vectors_per_file or ni > vectors_per_file:
            raise Exception("Invalid collection size or ni")
        file_num = size // vectors_per_file
        for i in range(file_num):
            file_name = gen_file_name(i, dimension, data_type)
            # logger.info("Load npy file: %s start" % file_name)
            data = np.load(file_name)
            # logger.info("Load npy file: %s end" % file_name)
            loops = vectors_per_file // ni
            for j in range(loops):
                vectors = data[j*ni:(j+1)*ni].tolist()
                if vectors:
                    ni_start_time = time.time()
                    # start inserting vectors
                    start_id = i * vectors_per_file + j * ni
                    end_id = start_id + len(vectors)
                    logger.info("Start id: %s, end id: %s" % (start_id, end_id))
                    ids = [k for k in range(start_id, end_id)]
                    _, ids = milvus.insert(vectors, ids=ids)
                    # milvus.flush()
                    logger.debug(milvus.count())
                    ni_end_time = time.time()
                    total_time = total_time + ni_end_time - ni_start_time

        qps = round(size / total_time, 2)
        ni_time = round(total_time / (loops * file_num), 2)
        bi_res["total_time"] = round(total_time, 2)
        bi_res["qps"] = qps
        bi_res["ni_time"] = ni_time
        return bi_res

    def do_query(self, milvus, collection_name, top_ks, nqs, run_count=1, search_param=None):
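        """Run search for every (top_k, nq) combination and return, per nq, the
        minimum latency in seconds observed over `run_count` runs of each top_k."""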
        bi_res = []
        (data_type, collection_size, index_file_size, dimension, metric_type) = parser.collection_parser(collection_name)
        base_query_vectors = get_vectors_from_binary(MAX_NQ, dimension, data_type)
        for nq in nqs:
            tmp_res = []
            vectors = base_query_vectors[0:nq]
            for top_k in top_ks:
                # avg_query_time = 0.0
                min_query_time = 0.0
                logger.info("Start query, query params: top-k: {}, nq: {}, actual length of vectors: {}".format(top_k, nq, len(vectors)))
                for i in range(run_count):
                    logger.info("Start run query, run %d of %s" % (i+1, run_count))
                    start_time = time.time()
                    milvus.query(vectors, top_k, search_param=search_param)
                    interval_time = time.time() - start_time
                    if (i == 0) or (min_query_time > interval_time):
                        min_query_time = interval_time
                logger.info("Min query time: %.2f" % min_query_time)
                tmp_res.append(round(min_query_time, 2))
            bi_res.append(tmp_res)
        return bi_res

    def do_query_qps(self, milvus, query_vectors, top_k, search_param):
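        """Run a single search and return its latency in seconds."""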
        start_time = time.time()
        milvus.query(query_vectors, top_k, search_param)
        end_time = time.time()
        return end_time - start_time

    def do_query_ids(self, milvus, collection_name, top_k, nq, search_param=None):
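        """Run a single search and return the result ids and distances,
        one list per query vector."""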
        (data_type, collection_size, index_file_size, dimension, metric_type) = parser.collection_parser(collection_name)
        base_query_vectors = get_vectors_from_binary(MAX_NQ, dimension, data_type)
        vectors = base_query_vectors[0:nq]
        logger.info("Start query, query params: top-k: {}, nq: {}, actual length of vectors: {}".format(top_k, nq, len(vectors)))
        query_res = milvus.query(vectors, top_k, search_param=search_param)
        result_ids = []
        result_distances = []
        for result in query_res:
            tmp = []
            tmp_distance = []
            for item in result:
                tmp.append(item.id)
                tmp_distance.append(item.distance)
            result_ids.append(tmp)
            result_distances.append(tmp_distance)
        return result_ids, result_distances

    def do_query_acc(self, milvus, collection_name, top_k, nq, id_store_name, search_param=None):
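        """Run a single search and dump the result ids to `id_store_name`,
        one tab-separated line per query vector."""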
        (data_type, collection_size, index_file_size, dimension, metric_type) = parser.collection_parser(collection_name)
        base_query_vectors = get_vectors_from_binary(MAX_NQ, dimension, data_type)
        vectors = base_query_vectors[0:nq]
        logger.info("Start query, query params: top-k: {}, nq: {}, actual length of vectors: {}".format(top_k, nq, len(vectors)))
        query_res = milvus.query(vectors, top_k, search_param=search_param)
        # if the file already exists, overwrite it
        if os.path.isfile(id_store_name):
            os.remove(id_store_name)
        with open(id_store_name, 'a+') as fd:
            for nq_item in query_res:
                for item in nq_item:
                    fd.write(str(item.id) + '\t')
                fd.write('\n')

    # compute accuracy
    def compute_accuracy(self, flat_file_name, index_file_name):
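        """Load the ids produced by a flat (exact) search and the ids produced by
        the index under test from file, and return the recall value."""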
        flat_id_list = []
        index_id_list = []
        logger.info("Loading flat id file: %s" % flat_file_name)
        with open(flat_file_name, 'r') as flat_id_fd:
            for line in flat_id_fd:
                tmp_list = line.strip("\n").strip().split("\t")
                flat_id_list.append(tmp_list)
        logger.info("Loading index id file: %s" % index_file_name)
        with open(index_file_name) as index_id_fd:
            for line in index_id_fd:
                tmp_list = line.strip("\n").strip().split("\t")
                index_id_list.append(tmp_list)
        if len(flat_id_list) != len(index_id_list):
            raise Exception("Flat and index result lengths do not match: <flat: %s, index: %s>, accuracy computation exiting ..." % (len(flat_id_list), len(index_id_list)))
        # get the accuracy
        return self.get_recall_value(flat_id_list, index_id_list)

    def get_recall_value(self, true_ids, result_ids):
        """
        Compute recall as the average ratio of the intersection size between
        true ids and result ids to the number of result ids per query.
        """
        sum_ratio = 0.0
        for index, item in enumerate(result_ids):
            # tmp = set(item).intersection(set(flat_id_list[index]))
            tmp = set(true_ids[index]).intersection(set(item))
            sum_ratio = sum_ratio + len(tmp) / len(item)
            # logger.debug(sum_ratio)
        return round(sum_ratio / len(result_ids), 3)

    # Implementation based on:
    # https://github.com/facebookresearch/faiss/blob/master/benchs/datasets.py
    def get_groundtruth_ids(self, collection_size):
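        """Load the ground truth neighbor ids for a SIFT collection of the given
        size from the matching .ivecs file."""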
        fname = GROUNDTRUTH_MAP[str(collection_size)]
        fname = SIFT_SRC_GROUNDTRUTH_DATA_DIR + "/" + fname
        a = np.fromfile(fname, dtype='int32')
        d = a[0]
        true_ids = a.reshape(-1, d + 1)[:, 1:].copy()
        return true_ids