|
1
|
|
|
# Copyright (c) Facebook, Inc. and its affiliates. |
|
2
|
|
|
# |
|
3
|
|
|
# This source code is licensed under the MIT license found in the |
|
4
|
|
|
# LICENSE file in the root directory of this source tree. |
|
5
|
|
|
|
|
6
|
|
|
#! /usr/bin/env python2 |
|
7
|
|
|
|
|
8
|
|
|
import os |
|
9
|
|
|
import numpy as np |
|
10
|
|
|
from matplotlib import pyplot |
|
11
|
|
|
|
|
12
|
|
|
import re |
|
13
|
|
|
|
|
14
|
|
|
from argparse import Namespace |
|
15
|
|
|
|
|
16
|
|
|
|
|
17
|
|
|
# the directory used in run_on_cluster.bash
basedir = '/mnt/vol/gfsai-east/ai-group/users/matthijs/bench_all_ivf/'
# per-run stdout/stderr logs are read from this subdirectory
logdir = basedir + 'logs/'


# which plot to output
db = 'bigann1B'  # database name; must match the 'db<name>' part of the log filenames
code_size = 8  # in bytes per vector; only indexes with this unitsize() are plotted
|
25
|
|
|
|
|
26
|
|
|
|
|
27
|
|
|
|
|
28
|
|
|
def unitsize(indexkey):
    """ size of one vector in the index

    Deduces the per-vector storage (in bytes) from a Faiss index factory
    string such as 'OPQ8_64,IVF65536,PQQ8' or 'PCAR32,IVF1024,SQ8'.

    Raises AssertionError when the string matches no known encoding.
    """
    # product quantizer: PQx stores one byte per sub-quantizer
    mo = re.match(r'.*,PQ(\d+)', indexkey)
    if mo:
        return int(mo.group(1))
    # otherwise a scalar quantizer: number of bits stored per dimension
    if indexkey.endswith('SQ8'):
        bits_per_d = 8
    elif indexkey.endswith('SQ4'):
        bits_per_d = 4
    elif indexkey.endswith('SQfp16'):
        bits_per_d = 16
    else:
        assert False
    # the output dimension of the vector transform (PCAR / OPQ / RR prefix)
    # determines how many dimensions are actually encoded.
    # NOTE: '//' (floor division) keeps the historical Python 2 integer
    # semantics of '/' when this script runs under Python 3.
    for pattern in (r'PCAR(\d+),.*', r'OPQ\d+_(\d+),.*', r'RR(\d+),.*'):
        mo = re.match(pattern, indexkey)
        if mo:
            return bits_per_d * int(mo.group(1)) // 8
    assert False
|
51
|
|
|
|
|
52
|
|
|
|
|
53
|
|
|
def dbsize_from_name(dbname):
    """Deduce the number of database vectors from the dataset-name suffix
    (e.g. 'bigann1B' -> 10**9). Raises AssertionError for unknown suffixes."""
    suffix_sizes = [
        ('1B', 10**9),
        ('100M', 10**8),
        ('10M', 10**7),
        ('1M', 10**6),
    ]
    for suffix, size in suffix_sizes:
        if dbname.endswith(suffix):
            return size
    # no known suffix matched
    assert False
|
65
|
|
|
|
|
66
|
|
|
|
|
67
|
|
|
def keep_latest_stdout(fnames):
    """Filter a list of filenames of the form <key>.<v>.stdout (where <v> is
    a single version character) and keep, for each key, only the entry with
    the lexicographically largest version."""
    stdouts = sorted(name for name in fnames if name.endswith('.stdout'))
    total = len(stdouts)
    latest = []
    for rank, name in enumerate(stdouts):
        # name[:-8] strips the version char plus '.stdout'; when the next
        # sorted entry shares that prefix, this entry is superseded by it
        is_latest_of_key = (
            rank + 1 == total or stdouts[rank + 1][:-8] != name[:-8]
        )
        if is_latest_of_key:
            latest.append(name)
    return latest
|
77
|
|
|
|
|
78
|
|
|
|
|
79
|
|
|
def parse_result_file(fname):
    """Parse one benchmark .stdout log.

    Returns (indexkey, res, keys, stats) where:
    - indexkey: the index factory string taken from the 'args:' line
      (NameError if the log contains no such line — assumed present,
      TODO confirm for truncated logs)
    - res: 2D float array, one row per result line after the
      'R@1 R@10 R@100' header
    - keys: first column (parameter name) of each result row
    - stats: dict of scalar stats scraped from the log header
    """
    # print fname
    # parser state: 0 = header section, 1 = just saw the R@ header line
    # (skip one line), 2 = inside the result table
    st = 0
    res = []
    keys = []
    stats = {}
    # version char of the run, from <key>.<v>.stdout naming
    stats['run_version'] = fname[-8]
    for l in open(fname):
        if st == 0:
            if l.startswith('CHRONOS_JOB_INSTANCE_ID'):
                stats['CHRONOS_JOB_INSTANCE_ID'] = l.split()[-1]
            if l.startswith('index size on disk:'):
                stats['index_size'] = int(l.split()[-1])
            if l.startswith('current RSS:'):
                stats['RSS'] = int(l.split()[-1])
            if l.startswith('precomputed tables size:'):
                stats['tables_size'] = int(l.split()[-1])
            if l.startswith('Setting nb of threads to'):
                stats['n_threads'] = int(l.split()[-1])
            if l.startswith(' add in'):
                stats['add_time'] = float(l.split()[-2])
            if l.startswith('args:'):
                # HACK: the log stores the argparse Namespace repr; eval-ing
                # it requires Namespace in scope and must only ever be run
                # on trusted (self-produced) logs
                args = eval(l[l.find(' '):])
                indexkey = args.indexkey
            elif 'R@1 R@10 R@100' in l:
                st = 1
            elif 'index size on disk:' in l:
                index_size = int(l.split()[-1])
        elif st == 1:
            # skip the column-header separator line
            st = 2
        elif st == 2:
            # result row: <key> <val> <val> ...
            fi = l.split()
            keys.append(fi[0])
            res.append([float(x) for x in fi[1:]])
    return indexkey, np.array(res), keys, stats
|
|
|
|
|
|
114
|
|
|
|
|
115
|
|
|
# run parsing
allres = {}    # indexkey -> result array from parse_result_file
allstats = {}  # indexkey -> stats dict from parse_result_file
nts = []       # thread counts seen, to check all runs used the same
missing = []   # log files that produced no results (failed runs)
versions = {}

fnames = keep_latest_stdout(os.listdir(logdir))
# print fnames
# filenames are in the form <key>.x.stdout
# where x is a version number (from a to z)
# keep only latest version of each name

for fname in fnames:
    # restrict to the selected database
    if not ('db' + db in fname and fname.endswith('.stdout')):
        continue
    indexkey, res, _, stats = parse_result_file(logdir + fname)
    if res.size == 0:
        # no result table: report the last stderr line for diagnosis
        missing.append(fname)
        errorline = open(
            logdir + fname.replace('.stdout', '.stderr')).readlines()
        if len(errorline) > 0:
            errorline = errorline[-1]
        else:
            errorline = 'NO STDERR'
        print fname, stats['CHRONOS_JOB_INSTANCE_ID'], errorline

    else:
        if indexkey in allres:
            if allstats[indexkey]['run_version'] > stats['run_version']:
                # don't use this run
                continue
        n_threads = stats.get('n_threads', 1)
        nts.append(n_threads)
        allres[indexkey] = res
        allstats[indexkey] = stats

# all kept runs must have used the same number of threads, otherwise
# the timings are not comparable
assert len(set(nts)) == 1
n_threads = nts[0]
|
154
|
|
|
|
|
155
|
|
|
|
|
156
|
|
|
def plot_tradeoffs(allres, code_size, recall_rank):
    """Plot the speed / recall tradeoff for all indexes with the given
    code_size, highlighting the Pareto-optimal methods.

    allres: dict indexkey -> result array (rows = parameter settings,
        columns = R@1, R@10, R@100, time)
    code_size: keep only indexes whose unitsize() matches this
    recall_rank: 1, 10 or 100 — selects which recall column to plot

    Returns (selected_methods, not_selected). Also reads module globals
    db, allstats and n_threads, and writes the figure under figs/.
    """
    dbsize = dbsize_from_name(db)
    # column index: R@1 -> 0, R@10 -> 1, R@100 -> 2
    recall_idx = int(np.log10(recall_rank))

    # 3 x N table: method id, recall, time — concatenated over all methods
    bigtab = []
    names = []

    for k,v in sorted(allres.items()):
        if v.ndim != 2: continue
        us = unitsize(k)
        if us != code_size: continue
        perf = v[:, recall_idx]
        times = v[:, 3]
        bigtab.append(
            np.vstack((
                np.ones(times.size, dtype=int) * len(names),
                perf, times
            ))
        )
        names.append(k)

    bigtab = np.hstack(bigtab)

    # sort all operating points by recall
    perm = np.argsort(bigtab[1, :])
    bigtab = bigtab[:, perm]

    # a point is Pareto-optimal iff no point with higher recall is faster:
    # running minimum of the times from the high-recall end
    times = np.minimum.accumulate(bigtab[2, ::-1])[::-1]
    selection = np.where(bigtab[2, :] == times)

    selected_methods = [names[i] for i in
                        np.unique(bigtab[0, selection].astype(int))]
    not_selected = list(set(names) - set(selected_methods))

    print "methods without an optimal OP: ", not_selected

    nq = 10000
    pyplot.title('database ' + db + ' code_size=%d' % code_size)

    # grayed out lines

    for k in not_selected:
        v = allres[k]
        if v.ndim != 2: continue
        us = unitsize(k)
        if us != code_size: continue

        # line style encodes the encoding type
        linestyle = (':' if 'PQ' in k else
                     '-.' if 'SQ4' in k else
                     '--' if 'SQ8' in k else '-')

        pyplot.semilogy(v[:, recall_idx], v[:, 3], label=None,
                        linestyle=linestyle,
                        marker='o' if 'HNSW' in k else '+',
                        color='#cccccc', linewidth=0.2)

    # important methods
    for k in selected_methods:
        v = allres[k]
        if v.ndim != 2: continue
        us = unitsize(k)
        if us != code_size: continue

        stats = allstats[k]
        tot_size = stats['index_size'] + stats['tables_size']
        id_size = 8 # 64 bit

        # format the add (build) time as h:mm or m:ss for the legend
        addt = ''
        if 'add_time' in stats:
            add_time = stats['add_time']
            if add_time > 7200:
                add_min = add_time / 60
                addt = ', %dh%02d' % (add_min / 60, add_min % 60)
            else:
                add_sec = int(add_time)
                addt = ', %dm%02d' % (add_sec / 60, add_sec % 60)


        # size overhead of the index relative to raw codes + 64-bit ids
        label = k + ' (size+%.1f%%%s)' % (
            tot_size / float((code_size + id_size) * dbsize) * 100 - 100,
            addt)

        linestyle = (':' if 'PQ' in k else
                     '-.' if 'SQ4' in k else
                     '--' if 'SQ8' in k else '-')

        pyplot.semilogy(v[:, recall_idx], v[:, 3], label=label,
                        linestyle=linestyle,
                        marker='o' if 'HNSW' in k else '+')

    # list the omitted (non-optimal) methods below the x label,
    # wrapped at roughly 80 characters
    if len(not_selected) == 0:
        om = ''
    else:
        om = '\nomitted:'
        nc = len(om)
        for m in not_selected:
            if nc > 80:
                om += '\n'
                nc = 0
            om += ' ' + m
            nc += len(m) + 1

    pyplot.xlabel('1-recall at %d %s' % (recall_rank, om) )
    pyplot.ylabel('search time per query (ms, %d threads)' % n_threads)
    pyplot.legend()
    pyplot.grid()
    pyplot.savefig('figs/tradeoffs_%s_cs%d_r%d.png' % (
        db, code_size, recall_rank))
    return selected_methods, not_selected
|
264
|
|
|
|
|
265
|
|
|
|
|
266
|
|
|
# enlarge the figure so the curves and the (long) legend stay readable
pyplot.gcf().set_size_inches(15, 10)

plot_tradeoffs(allres, code_size=code_size, recall_rank=1)
|
269
|
|
|
|