import os
import logging
import pdb
import time
import random
from multiprocessing import Process
from itertools import product
import numpy as np
import sklearn.preprocessing
from client import MilvusClient
import utils
import parser

logger = logging.getLogger("milvus_benchmark.runner")

VECTORS_PER_FILE = 1000000
SIFT_VECTORS_PER_FILE = 100000
JACCARD_VECTORS_PER_FILE = 2000000

MAX_NQ = 10001
FILE_PREFIX = "binary_"

# FOLDER_NAME = 'ann_1000m/source_data'
SRC_BINARY_DATA_DIR = '/test/milvus/raw_data/random/'
SIFT_SRC_DATA_DIR = '/test/milvus/raw_data/sift1b/'
DEEP_SRC_DATA_DIR = '/test/milvus/raw_data/deep1b/'
JACCARD_SRC_DATA_DIR = '/test/milvus/raw_data/jaccard/'
HAMMING_SRC_DATA_DIR = '/test/milvus/raw_data/jaccard/'
STRUCTURE_SRC_DATA_DIR = '/test/milvus/raw_data/jaccard/'
SIFT_SRC_GROUNDTRUTH_DATA_DIR = SIFT_SRC_DATA_DIR + 'gnd'

WARM_TOP_K = 1
WARM_NQ = 1
DEFAULT_DIM = 512


GROUNDTRUTH_MAP = {
    "1000000": "idx_1M.ivecs",
    "2000000": "idx_2M.ivecs",
    "5000000": "idx_5M.ivecs",
    "10000000": "idx_10M.ivecs",
    "20000000": "idx_20M.ivecs",
    "50000000": "idx_50M.ivecs",
    "100000000": "idx_100M.ivecs",
    "200000000": "idx_200M.ivecs",
    "500000000": "idx_500M.ivecs",
    "1000000000": "idx_1000M.ivecs",
}


def gen_file_name(idx, dimension, data_type):
    s = "%05d" % idx
    fname = FILE_PREFIX + str(dimension) + "d_" + s + ".npy"
    if data_type == "random":
        fname = SRC_BINARY_DATA_DIR + fname
    elif data_type == "sift":
        fname = SIFT_SRC_DATA_DIR + fname
    elif data_type == "deep":
        fname = DEEP_SRC_DATA_DIR + fname
    elif data_type == "jaccard":
        fname = JACCARD_SRC_DATA_DIR + fname
    elif data_type == "hamming":
        fname = HAMMING_SRC_DATA_DIR + fname
    elif data_type == "sub" or data_type == "super":
        fname = STRUCTURE_SRC_DATA_DIR + fname
    return fname
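# Example (follows directly from FILE_PREFIX and the *_SRC_DATA_DIR constants above):
#   gen_file_name(0, 128, "sift") -> '/test/milvus/raw_data/sift1b/binary_128d_00000.npy'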


def get_vectors_from_binary(nq, dimension, data_type):
    # use the first query file; nq should not be larger than MAX_NQ
    if nq > MAX_NQ:
        raise Exception("nq %d is larger than MAX_NQ %d" % (nq, MAX_NQ))
    if data_type == "random":
        file_name = SRC_BINARY_DATA_DIR + 'query_%d.npy' % dimension
    elif data_type == "sift":
        file_name = SIFT_SRC_DATA_DIR + 'query.npy'
    elif data_type == "deep":
        file_name = DEEP_SRC_DATA_DIR + 'query.npy'
    elif data_type == "jaccard":
        file_name = JACCARD_SRC_DATA_DIR + 'query.npy'
    elif data_type == "hamming":
        file_name = HAMMING_SRC_DATA_DIR + 'query.npy'
    elif data_type == "sub" or data_type == "super":
        file_name = STRUCTURE_SRC_DATA_DIR + 'query.npy'
    data = np.load(file_name)
    vectors = data[0:nq].tolist()
    return vectors
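# Example (follows from the constants above): get_vectors_from_binary(10, 128, "sift")
# loads '/test/milvus/raw_data/sift1b/query.npy' and returns its first 10 rows as lists.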


class Runner(object):
    def __init__(self):
        """Run each test defined in the suites."""
        pass

    def normalize(self, metric_type, X):
        if metric_type == "ip":
            logger.info("Set normalize for metric_type: %s" % metric_type)
            X = sklearn.preprocessing.normalize(X, axis=1, norm='l2')
            X = X.astype(np.float32)
        elif metric_type == "l2":
            X = X.astype(np.float32)
        elif metric_type in ["jaccard", "hamming", "sub", "super"]:
            tmp = []
            for _, item in enumerate(X):
                new_vector = bytes(np.packbits(item, axis=-1).tolist())
                tmp.append(new_vector)
            X = tmp
        return X
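    # Example: for metric_type "jaccard" (or "hamming"/"sub"/"super"), a 16-dimensional
    # 0/1 vector is packed by np.packbits into 2 bytes, so each entity in X becomes a
    # `bytes` object of dimension // 8 bytes; "ip" additionally L2-normalizes each row.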

    def generate_combinations(self, args):
        if isinstance(args, list):
            args = [el if isinstance(el, list) else [el] for el in args]
            return [list(x) for x in product(*args)]
        elif isinstance(args, dict):
            flat = []
            for k, v in args.items():
                if isinstance(v, list):
                    flat.append([(k, el) for el in v])
                else:
                    flat.append([(k, v)])
            return [dict(x) for x in product(*flat)]
        else:
            raise TypeError("No args handling exists for %s" % type(args).__name__)
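    # Example: generate_combinations({"nlist": [1024, 4096], "nprobe": 32})
    # -> [{"nlist": 1024, "nprobe": 32}, {"nlist": 4096, "nprobe": 32}],
    # and generate_combinations([[1, 10], 100]) -> [[1, 100], [10, 100]].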

    def do_insert(self, milvus, collection_name, data_type, dimension, size, ni):
        '''
        @params:
            milvus: server connection instance
            dimension: collection dimension
            # index_file_size: size that triggers file merge
            size: row count of vectors to be inserted
            ni: row count of vectors to be inserted per batch
            # store_id: whether to store the ids returned by add_vectors
        @return:
            total_time: total time of all insert operations
            qps: vectors added per second
            ni_time: average time per insert operation
        '''
        bi_res = {}
        total_time = 0.0
        qps = 0.0
        ni_time = 0.0
        if data_type == "random":
            if dimension == 512:
                vectors_per_file = VECTORS_PER_FILE
            elif dimension == 4096:
                vectors_per_file = 100000
            elif dimension == 16384:
                vectors_per_file = 10000
        elif data_type == "sift":
            vectors_per_file = SIFT_VECTORS_PER_FILE
        elif data_type in ["jaccard", "hamming", "sub", "super"]:
            vectors_per_file = JACCARD_VECTORS_PER_FILE
        else:
            raise Exception("data_type: %s not supported" % data_type)
        if size % vectors_per_file or ni > vectors_per_file:
            raise Exception("Invalid collection size or ni")
        file_num = size // vectors_per_file
        for i in range(file_num):
            file_name = gen_file_name(i, dimension, data_type)
            # logger.info("Load npy file: %s start" % file_name)
            data = np.load(file_name)
            # logger.info("Load npy file: %s end" % file_name)
            loops = vectors_per_file // ni
            for j in range(loops):
                vectors = data[j*ni:(j+1)*ni].tolist()
                if vectors:
                    ni_start_time = time.time()
                    # start inserting vectors
                    start_id = i * vectors_per_file + j * ni
                    end_id = start_id + len(vectors)
                    logger.info("Start id: %s, end id: %s" % (start_id, end_id))
                    ids = [k for k in range(start_id, end_id)]
                    _, ids = milvus.insert(vectors, ids=ids)
                    # milvus.flush()
                    logger.debug(milvus.count())
                    ni_end_time = time.time()
                    total_time = total_time + ni_end_time - ni_start_time

        qps = round(size / total_time, 2)
        ni_time = round(total_time / (loops * file_num), 2)
        bi_res["total_time"] = round(total_time, 2)
        bi_res["qps"] = qps
        bi_res["ni_time"] = ni_time
        return bi_res
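    # Example return value (illustrative numbers only):
    #   {"total_time": 120.5, "qps": 8298.76, "ni_time": 1.2}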

    def do_query(self, milvus, collection_name, top_ks, nqs, run_count=1, search_param=None):
        bi_res = []
        (data_type, collection_size, index_file_size, dimension, metric_type) = parser.collection_parser(collection_name)
        base_query_vectors = get_vectors_from_binary(MAX_NQ, dimension, data_type)
        for nq in nqs:
            tmp_res = []
            vectors = base_query_vectors[0:nq]
            for top_k in top_ks:
                # avg_query_time = 0.0
                min_query_time = 0.0
                logger.info("Start query, query params: top-k: {}, nq: {}, actual length of vectors: {}".format(top_k, nq, len(vectors)))
                for i in range(run_count):
                    logger.info("Start run query, run %d of %s" % (i+1, run_count))
                    start_time = time.time()
                    milvus.query(vectors, top_k, search_param=search_param)
                    interval_time = time.time() - start_time
                    if (i == 0) or (min_query_time > interval_time):
                        min_query_time = interval_time
                logger.info("Min query time: %.2f" % min_query_time)
                tmp_res.append(round(min_query_time, 2))
            bi_res.append(tmp_res)
        return bi_res
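    # bi_res is indexed as bi_res[nq_index][top_k_index] and holds the fastest of the
    # run_count runs in seconds, e.g. [[0.01, 0.02], [0.05, 0.08]] for nqs=[1, 10] and
    # top_ks=[1, 10] (illustrative timings).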

    def do_query_qps(self, milvus, query_vectors, top_k, search_param):
        start_time = time.time()
        milvus.query(query_vectors, top_k, search_param)
        end_time = time.time()
        return end_time - start_time

    def do_query_ids(self, milvus, collection_name, top_k, nq, search_param=None):
        (data_type, collection_size, index_file_size, dimension, metric_type) = parser.collection_parser(collection_name)
        base_query_vectors = get_vectors_from_binary(MAX_NQ, dimension, data_type)
        vectors = base_query_vectors[0:nq]
        logger.info("Start query, query params: top-k: {}, nq: {}, actual length of vectors: {}".format(top_k, nq, len(vectors)))
        query_res = milvus.query(vectors, top_k, search_param=search_param)
        result_ids = []
        result_distances = []
        for result in query_res:
            tmp = []
            tmp_distance = []
            for item in result:
                tmp.append(item.id)
                tmp_distance.append(item.distance)
            result_ids.append(tmp)
            result_distances.append(tmp_distance)
        return result_ids, result_distances
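    # result_ids[i][j] / result_distances[i][j] are the id and distance of the j-th
    # nearest neighbour returned for the i-th query vector.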

    def do_query_acc(self, milvus, collection_name, top_k, nq, id_store_name, search_param=None):
        (data_type, collection_size, index_file_size, dimension, metric_type) = parser.collection_parser(collection_name)
        base_query_vectors = get_vectors_from_binary(MAX_NQ, dimension, data_type)
        vectors = base_query_vectors[0:nq]
        logger.info("Start query, query params: top-k: {}, nq: {}, actual length of vectors: {}".format(top_k, nq, len(vectors)))
        query_res = milvus.query(vectors, top_k, search_param=search_param)
        # if the result file already exists, overwrite it
        if os.path.isfile(id_store_name):
            os.remove(id_store_name)
        with open(id_store_name, 'a+') as fd:
            for nq_item in query_res:
                for item in nq_item:
                    fd.write(str(item.id) + '\t')
                fd.write('\n')

    # compute and print accuracy
    def compute_accuracy(self, flat_file_name, index_file_name):
        flat_id_list = []
        index_id_list = []
        logger.info("Loading flat id file: %s" % flat_file_name)
        with open(flat_file_name, 'r') as flat_id_fd:
            for line in flat_id_fd:
                tmp_list = line.strip("\n").strip().split("\t")
                flat_id_list.append(tmp_list)
        logger.info("Loading index id file: %s" % index_file_name)
        with open(index_file_name) as index_id_fd:
            for line in index_id_fd:
                tmp_list = line.strip("\n").strip().split("\t")
                index_id_list.append(tmp_list)
        if len(flat_id_list) != len(index_id_list):
            raise Exception("Flat and index result lengths do not match: <flat: %s, index: %s>, accuracy computation exiting ..." % (len(flat_id_list), len(index_id_list)))
        # get the accuracy
        return self.get_recall_value(flat_id_list, index_id_list)

    def get_recall_value(self, true_ids, result_ids):
        """
        Use the intersection length to compute the average recall.
        """
        sum_ratio = 0.0
        for index, item in enumerate(result_ids):
            # tmp = set(item).intersection(set(flat_id_list[index]))
            tmp = set(true_ids[index]).intersection(set(item))
            sum_ratio = sum_ratio + len(tmp) / len(item)
            # logger.debug(sum_ratio)
        return round(sum_ratio / len(result_ids), 3)
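    # Example: get_recall_value([[1, 2, 3, 4]], [[1, 2, 9, 4]]) -> 0.75
    # (3 of the 4 returned ids appear in the ground-truth set).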

    """
    Implementation based on:
        https://github.com/facebookresearch/faiss/blob/master/benchs/datasets.py
    """
    def get_groundtruth_ids(self, collection_size):
        fname = GROUNDTRUTH_MAP[str(collection_size)]
        fname = SIFT_SRC_GROUNDTRUTH_DATA_DIR + "/" + fname
        a = np.fromfile(fname, dtype='int32')
        d = a[0]
        true_ids = a.reshape(-1, d + 1)[:, 1:].copy()
        return true_ids
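    # The .ivecs ground-truth files store, for every query, an int32 row of the form
    # [k, id_0, ..., id_{k-1}]; reshaping to (-1, d + 1) and dropping the first column
    # therefore yields one row of ground-truth ids per query.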
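

# A minimal usage sketch (an assumption for illustration, not part of the original
# benchmark flow): it exercises only the helpers that need no running Milvus server.
if __name__ == "__main__":
    runner = Runner()
    # Expand a search-parameter grid into concrete parameter sets.
    search_params = runner.generate_combinations({"nprobe": [1, 8, 64]})
    print(search_params)  # [{'nprobe': 1}, {'nprobe': 8}, {'nprobe': 64}]
    # Compare a fake index result against fake ground truth to get a recall value.
    print(runner.get_recall_value([[1, 2, 3, 4]], [[1, 2, 9, 4]]))  # 0.75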