1
|
|
|
# Copyright (c) Facebook, Inc. and its affiliates. |
2
|
|
|
# |
3
|
|
|
# This source code is licensed under the MIT license found in the |
4
|
|
|
# LICENSE file in the root directory of this source tree. |
5
|
|
|
|
6
|
|
|
#! /usr/bin/env python2 |
7
|
|
|
|
8
|
|
|
import os |
9
|
|
|
import numpy as np |
10
|
|
|
from matplotlib import pyplot |
11
|
|
|
|
12
|
|
|
import re |
13
|
|
|
|
14
|
|
|
from argparse import Namespace |
15
|
|
|
|
16
|
|
|
|
17
|
|
|
# Root directory used by run_on_cluster.bash; per-run logs are read from
# its logs/ subdirectory.
basedir = '/mnt/vol/gfsai-east/ai-group/users/matthijs/bench_all_ivf/'
logdir = '{}logs/'.format(basedir)

# Selection of the plot to produce: the size of one encoded vector (as
# returned by unitsize) that index keys are filtered on, and the database.
code_size = 8
db = 'bigann1B'
25
|
|
|
|
26
|
|
|
|
27
|
|
|
|
28
|
|
|
def unitsize(indexkey):
    """Return the size (in bytes) of one encoded vector for an index key.

    Two families of keys are recognised:

    * keys containing a ``,PQ<m>`` component: the size is ``m``;
    * keys ending in ``SQ8``, ``SQ4`` or ``SQfp16`` (8, 4 or 16 bits per
      component) whose leading component is ``PCAR<d>,``,
      ``OPQ<m>_<d>,`` or ``RR<d>,``: the size is ``bits * d // 8``.

    Raises:
        ValueError: if the key matches none of the supported patterns.
    """
    mo = re.match(r'.*,PQ(\d+)', indexkey)
    if mo:
        return int(mo.group(1))

    if indexkey.endswith('SQ8'):
        bits_per_d = 8
    elif indexkey.endswith('SQ4'):
        bits_per_d = 4
    elif indexkey.endswith('SQfp16'):
        bits_per_d = 16
    else:
        raise ValueError('unsupported encoding in index key %r' % indexkey)

    # The number captured from the leading transform is used as the number
    # of components the scalar quantizer encodes.
    for pattern in (r'PCAR(\d+),', r'OPQ\d+_(\d+),', r'RR(\d+),'):
        mo = re.match(pattern, indexkey)
        if mo:
            # Floor division keeps an int result on both Python 2 and 3
            # (a plain `/` would yield a float on Python 3).
            return bits_per_d * int(mo.group(1)) // 8

    raise ValueError('cannot determine the encoded dimension of index key %r'
                     % indexkey)
51
|
|
|
|
52
|
|
|
|
53
|
|
|
def dbsize_from_name(dbname):
    """Return the number of database vectors encoded in a dataset name.

    The name must end with a count followed by a unit letter, e.g.
    ``bigann1B`` -> 10**9, ``deep100M`` -> 10**8, ``sift1M`` -> 10**6.
    Supported units: ``K`` (10**3), ``M`` (10**6) and ``B`` (10**9).

    Raises:
        ValueError: if the name does not end with such a suffix.
    """
    units = {'K': 10**3, 'M': 10**6, 'B': 10**9}
    mo = re.search(r'(\d+)([KMB])$', dbname)
    if not mo:
        raise ValueError('cannot infer database size from name %r' % dbname)
    return int(mo.group(1)) * units[mo.group(2)]
65
|
|
|
|
66
|
|
|
|
67
|
|
|
def keep_latest_stdout(fnames):
    """Keep only the newest ``.stdout`` log of each run.

    Names are expected to look like ``<key>.<v>.stdout`` with ``<v>`` a
    single version letter. Names not ending in ``.stdout`` are dropped, the
    rest are sorted, and of each run of consecutive names sharing the same
    ``<key>.`` prefix only the last one is kept. The result stays sorted.
    """
    # '<v>.stdout' is 8 characters; cutting it leaves the '<key>.' prefix.
    tail = 8
    logs = sorted(f for f in fnames if f.endswith('.stdout'))
    followers = logs[1:] + [None]
    return [cur for cur, nxt in zip(logs, followers)
            if nxt is None or nxt[:-tail] != cur[:-tail]]
77
|
|
|
|
78
|
|
|
|
79
|
|
|
def parse_result_file(fname):
    """Parse one benchmark ``.stdout`` log.

    The file is read with a small state machine:

    * state 0 (preamble): collect run statistics from lines with known
      prefixes and the index key from the ``args:`` line, until a line
      containing the result-table header ``'R@1 R@10 R@100'``;
    * state 1: skip exactly one line after that header;
    * state 2: every remaining non-blank line is a result row
      ``<key> <float> <float> ...``.

    Args:
        fname: path of the log. Its 8th character from the end (the version
            letter in ``<key>.<v>.stdout``) is stored as
            ``stats['run_version']``.

    Returns:
        ``(indexkey, res, keys, stats)``: ``indexkey`` is the ``indexkey``
        attribute of the logged arguments (``None`` when the log has no
        ``args:`` line), ``res`` is ``np.array`` of the numeric columns of
        the result rows (empty when no row was found), ``keys`` is the first
        column of each row and ``stats`` the dict of collected statistics.
    """
    # NOTE(review): the spacing of the prefix/header literals below must
    # match the benchmark's log output exactly -- confirm against a real log.
    int_fields = (
        ('index size on disk:', 'index_size'),
        ('current RSS:', 'RSS'),
        ('precomputed tables size:', 'tables_size'),
        ('Setting nb of threads to', 'n_threads'),
    )
    st = 0
    res = []
    keys = []
    stats = {'run_version': fname[-8]}
    # Stays None when the log has no 'args:' line (e.g. a run that died
    # early), so the caller gets an empty result instead of a crash.
    indexkey = None
    with open(fname) as f:
        for line in f:
            if st == 0:
                if line.startswith('CHRONOS_JOB_INSTANCE_ID'):
                    stats['CHRONOS_JOB_INSTANCE_ID'] = line.split()[-1]
                for prefix, field in int_fields:
                    if line.startswith(prefix):
                        stats[field] = int(line.split()[-1])
                if line.startswith(' add in'):
                    # the value is the second-to-last token of the line
                    stats['add_time'] = float(line.split()[-2])
                if line.startswith('args:'):
                    # NOTE(review): eval() executes the logged text, expected
                    # to be a `Namespace(...)` repr (so `Namespace` must be
                    # importable in this module). Only use on logs written
                    # by the trusted benchmark script.
                    args = eval(line[line.find(' '):])
                    indexkey = args.indexkey
                elif 'R@1 R@10 R@100' in line:
                    st = 1
            elif st == 1:
                st = 2
            else:
                fields = line.split()
                if not fields:
                    # tolerate blank lines, e.g. at the end of the file
                    continue
                keys.append(fields[0])
                res.append([float(x) for x in fields[1:]])
    return indexkey, np.array(res), keys, stats
|
|
|
|
114
|
|
|
|
115
|
|
|
# run parsing
# Log file names have the form <key>.x.stdout, where x is a version letter
# (from a to z); keep_latest_stdout keeps only the latest version of each.
allres = {}     # index key -> result array of the accepted run
allstats = {}   # index key -> stats dict of the accepted run
nts = []        # thread count of each accepted run
missing = []    # logs that contain no result rows
versions = {}

fnames = keep_latest_stdout(os.listdir(logdir))

for fname in fnames:
    if not ('db' + db in fname and fname.endswith('.stdout')):
        continue
    indexkey, res, _, stats = parse_result_file(logdir + fname)
    if res.size == 0:
        # No results: report the last stderr line to help triage the run.
        missing.append(fname)
        errfile = logdir + fname.replace('.stdout', '.stderr')
        try:
            with open(errfile) as f:
                errlines = f.readlines()
        except (IOError, OSError):
            errlines = []
        errorline = errlines[-1].rstrip('\n') if errlines else 'NO STDERR'
        # single-string print: identical output on Python 2 and 3
        print('%s %s %s' % (fname,
                            stats.get('CHRONOS_JOB_INSTANCE_ID', '?'),
                            errorline))
        continue
    if (indexkey in allres and
            allstats[indexkey]['run_version'] > stats['run_version']):
        # an already accepted run of this key has a later version letter
        continue
    nts.append(stats.get('n_threads', 1))
    allres[indexkey] = res
    allstats[indexkey] = stats

if not nts:
    raise RuntimeError('no results found for db %s in %s' % (db, logdir))
if len(set(nts)) != 1:
    raise RuntimeError('runs use different thread counts: %s'
                       % sorted(set(nts)))
n_threads = nts[0]
154
|
|
|
|
155
|
|
|
|
156
|
|
|
def _linestyle(indexkey):
    """Return the matplotlib line style for an index key's encoding."""
    # Test for a ',PQ' component rather than any 'PQ' substring, so that
    # OPQ-rotated scalar-quantizer keys (e.g. 'OPQ16_64,IVF4096,SQ8') get
    # the style of their SQ encoding instead of the PQ one.
    if ',PQ' in indexkey:
        return ':'
    if 'SQ4' in indexkey:
        return '-.'
    if 'SQ8' in indexkey:
        return '--'
    return '-'


def _format_add_time(add_time):
    """Format an add time in seconds as ', <h>h<mm>' above 7200 s, else as
    ', <m>m<ss>' (values truncated toward zero)."""
    if add_time > 7200:
        hours, minutes = divmod(int(add_time // 60), 60)
        return ', %dh%02d' % (hours, minutes)
    minutes, seconds = divmod(int(add_time), 60)
    return ', %dm%02d' % (minutes, seconds)


def _omitted_text(names):
    """Build the x-label suffix listing the omitted method names: a newline,
    'omitted:' and the names, starting a new line once the current one
    exceeds 80 characters. Returns '' when there are no names."""
    if not names:
        return ''
    om = '\nomitted:'
    nc = len(om)
    for m in names:
        if nc > 80:
            om += '\n'
            nc = 0
        om += ' ' + m
        nc += len(m) + 1
    return om


def plot_tradeoffs(allres, code_size, recall_rank):
    """Plot search time vs. recall for all indexes of one code size and save
    the figure to figs/tradeoffs_<db>_cs<code_size>_r<recall_rank>.png.

    The operating points of all methods are sorted by recall. A point is
    optimal when its time equals the minimum time over itself and every point
    sorted after it (ties in recall are broken by sort order). Methods with
    at least one optimal point are drawn normally, with a label giving
    their size overhead and add time. The others are drawn as thin grey
    lines and listed in the x label.

    Also reads the module-level globals ``db``, ``allstats`` and
    ``n_threads``.

    Args:
        allres: dict index key -> result array. Column
            ``int(log10(recall_rank))`` is plotted as recall and column 3 as
            search time; entries that are not 2-D are ignored.
        code_size: only keys whose ``unitsize`` equals this are plotted.
        recall_rank: rank selecting the recall column (1 -> column 0,
            10 -> column 1, 100 -> column 2).

    Returns:
        ``(selected_methods, not_selected)``: lists of index keys, the
        latter sorted.

    Raises:
        ValueError: if no result matches ``code_size``.
    """
    dbsize = dbsize_from_name(db)
    recall_idx = int(np.log10(recall_rank))

    # Only 2-D result tables of indexes with the requested code size.
    curves = [(k, v) for k, v in sorted(allres.items())
              if v.ndim == 2 and unitsize(k) == code_size]
    if not curves:
        raise ValueError('no results with code_size=%d for database %s'
                         % (code_size, db))
    names = [k for k, _ in curves]

    # One column per operating point: (method number, recall, time).
    bigtab = np.hstack([
        np.vstack((np.full(v.shape[0], i), v[:, recall_idx], v[:, 3]))
        for i, (_, v) in enumerate(curves)
    ])

    # Sort by recall, then keep the points whose time is the minimum over
    # themselves and all points sorted after them.
    bigtab = bigtab[:, np.argsort(bigtab[1, :])]
    best_time_from_here = np.minimum.accumulate(bigtab[2, ::-1])[::-1]
    optimal = bigtab[2, :] == best_time_from_here

    selected_methods = [names[i] for i in
                        np.unique(bigtab[0, optimal].astype(int))]
    # sorted so the printed list and the x label are deterministic
    not_selected = sorted(set(names) - set(selected_methods))

    # single-string print: identical output on Python 2 and 3
    print('methods without an optimal OP:  %s' % (not_selected,))

    pyplot.title('database ' + db + ' code_size=%d' % code_size)

    # grayed out lines, without legend entries
    for k in not_selected:
        v = allres[k]
        pyplot.semilogy(v[:, recall_idx], v[:, 3], label=None,
                        linestyle=_linestyle(k),
                        marker='o' if 'HNSW' in k else '+',
                        color='#cccccc', linewidth=0.2)

    # important methods, labelled with size overhead and add time
    id_size = 8  # 64-bit ids
    for k in selected_methods:
        v = allres[k]
        stats = allstats[k]
        # NOTE(review): a missing 'tables_size' is treated as 0 bytes.
        tot_size = stats['index_size'] + stats.get('tables_size', 0)
        addt = (_format_add_time(stats['add_time'])
                if 'add_time' in stats else '')
        # percentage by which the measured size exceeds
        # (code_size + id_size) bytes per database vector
        overhead = (tot_size / float((code_size + id_size) * dbsize) * 100
                    - 100)
        label = k + ' (size+%.1f%%%s)' % (overhead, addt)
        pyplot.semilogy(v[:, recall_idx], v[:, 3], label=label,
                        linestyle=_linestyle(k),
                        marker='o' if 'HNSW' in k else '+')

    pyplot.xlabel('1-recall at %d %s' % (recall_rank,
                                         _omitted_text(not_selected)))
    pyplot.ylabel('search time per query (ms, %d threads)' % n_threads)
    pyplot.legend()
    pyplot.grid()
    # savefig does not create missing directories
    if not os.path.isdir('figs'):
        os.makedirs('figs')
    pyplot.savefig('figs/tradeoffs_%s_cs%d_r%d.png' % (
        db, code_size, recall_rank))
    return selected_methods, not_selected
264
|
|
|
|
265
|
|
|
|
266
|
|
|
# Enlarge the current figure to 15x10 inches, then draw and save the
# trade-off plot for the configured code size at recall rank 1.
figure = pyplot.gcf()
figure.set_size_inches(15, 10)
plot_tradeoffs(allres, recall_rank=1, code_size=code_size)
269
|
|
|
|