parse_bench_all_ivf.plot_tradeoffs() — rated F
(last analyzed)

Complexity

Conditions 23

Size

Total Lines 108
Code Lines 81

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 23
eloc 81
nop 3
dl 0
loc 108
rs 0
c 0
b 0
f 0

How to fix: Long Method · Complexity

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include:

Complexity

Complex functions like parse_bench_all_ivf.plot_tradeoffs() often do a lot of different things. To break such a function down, we need to identify a cohesive component within it. A common approach to find such a component is to look for statements and variables that share the same prefixes, or suffixes.

Once you have determined the statements that belong together, you can apply the Extract Method refactoring. If the component makes sense as its own class, Extract Class is also a candidate.

1
# Copyright (c) Facebook, Inc. and its affiliates.
2
#
3
# This source code is licensed under the MIT license found in the
4
# LICENSE file in the root directory of this source tree.
5
6
#! /usr/bin/env python2
7
8
import os
9
import numpy as np
10
from matplotlib import pyplot
11
12
import re
13
14
from argparse import Namespace
15
16
17
# the directory used in run_on_cluster.bash
18
basedir = '/mnt/vol/gfsai-east/ai-group/users/matthijs/bench_all_ivf/'
19
logdir = basedir + 'logs/'
20
21
22
# which plot to output
23
db = 'bigann1B'
24
code_size = 8
25
26
27
28
def unitsize(indexkey):
29
    """ size of one vector in the index """
30
    mo = re.match('.*,PQ(\\d+)', indexkey)
31
    if mo:
32
        return int(mo.group(1))
33
    if indexkey.endswith('SQ8'):
34
        bits_per_d = 8
35
    elif indexkey.endswith('SQ4'):
36
        bits_per_d = 4
37
    elif indexkey.endswith('SQfp16'):
38
        bits_per_d = 16
39
    else:
40
        assert False
41
    mo = re.match('PCAR(\\d+),.*', indexkey)
42
    if mo:
43
        return bits_per_d * int(mo.group(1)) / 8
0 ignored issues
show
introduced by
The variable bits_per_d does not seem to be defined for all execution paths.
Loading history...
44
    mo = re.match('OPQ\\d+_(\\d+),.*', indexkey)
45
    if mo:
46
        return bits_per_d * int(mo.group(1)) / 8
47
    mo = re.match('RR(\\d+),.*', indexkey)
48
    if mo:
49
        return bits_per_d * int(mo.group(1)) / 8
50
    assert False
51
52
53
def dbsize_from_name(dbname):
54
    sufs = {
55
        '1B': 10**9,
56
        '100M': 10**8,
57
        '10M': 10**7,
58
        '1M': 10**6,
59
    }
60
    for s in sufs:
61
        if dbname.endswith(s):
62
            return sufs[s]
63
    else:
64
        assert False
65
66
67
def keep_latest_stdout(fnames):
68
    fnames = [fname for fname in fnames if fname.endswith('.stdout')]
69
    fnames.sort()
70
    n = len(fnames)
71
    fnames2 = []
72
    for i, fname in enumerate(fnames):
73
        if i + 1 < n and fnames[i + 1][:-8] == fname[:-8]:
74
            continue
75
        fnames2.append(fname)
76
    return fnames2
77
78
79
def parse_result_file(fname):
80
    # print fname
81
    st = 0
82
    res = []
83
    keys = []
84
    stats = {}
85
    stats['run_version'] = fname[-8]
86
    for l in open(fname):
87
        if st == 0:
88
            if l.startswith('CHRONOS_JOB_INSTANCE_ID'):
89
                stats['CHRONOS_JOB_INSTANCE_ID'] = l.split()[-1]
90
            if l.startswith('index size on disk:'):
91
                stats['index_size'] = int(l.split()[-1])
92
            if l.startswith('current RSS:'):
93
                stats['RSS'] = int(l.split()[-1])
94
            if l.startswith('precomputed tables size:'):
95
                stats['tables_size'] = int(l.split()[-1])
96
            if l.startswith('Setting nb of threads to'):
97
                stats['n_threads'] = int(l.split()[-1])
98
            if l.startswith('  add in'):
99
                stats['add_time'] = float(l.split()[-2])
100
            if l.startswith('args:'):
101
                args = eval(l[l.find(' '):])
102
                indexkey = args.indexkey
103
            elif 'R@1   R@10  R@100' in l:
104
                st = 1
105
            elif 'index size on disk:' in l:
106
                index_size = int(l.split()[-1])
107
        elif st == 1:
108
            st = 2
109
        elif st == 2:
110
            fi = l.split()
111
            keys.append(fi[0])
112
            res.append([float(x) for x in fi[1:]])
113
    return indexkey, np.array(res), keys, stats
0 ignored issues
show
introduced by
The variable indexkey does not seem to be defined in case the for loop on line 86 is not entered. Are you sure this can never be the case?
Loading history...
114
115
# run parsing
116
allres = {}
117
allstats = {}
118
nts = []
119
missing = []
120
versions = {}
121
122
fnames = keep_latest_stdout(os.listdir(logdir))
123
# print fnames
124
# filenames are in the form <key>.x.stdout
125
# where x is a version number (from a to z)
126
# keep only latest version of each name
127
128
for fname in fnames:
129
    if not ('db' + db in fname and fname.endswith('.stdout')):
130
        continue
131
    indexkey, res, _, stats = parse_result_file(logdir + fname)
132
    if res.size == 0:
133
        missing.append(fname)
134
        errorline = open(
135
            logdir + fname.replace('.stdout', '.stderr')).readlines()
136
        if len(errorline) > 0:
137
            errorline = errorline[-1]
138
        else:
139
            errorline = 'NO STDERR'
140
        print fname, stats['CHRONOS_JOB_INSTANCE_ID'], errorline
141
142
    else:
143
        if indexkey in allres:
144
            if allstats[indexkey]['run_version'] > stats['run_version']:
145
                # don't use this run
146
                continue
147
        n_threads = stats.get('n_threads', 1)
148
        nts.append(n_threads)
149
        allres[indexkey] = res
150
        allstats[indexkey] = stats
151
152
assert len(set(nts)) == 1
153
n_threads = nts[0]
154
155
156
def plot_tradeoffs(allres, code_size, recall_rank):
157
    dbsize = dbsize_from_name(db)
158
    recall_idx = int(np.log10(recall_rank))
159
160
    bigtab = []
161
    names = []
162
163
    for k,v in sorted(allres.items()):
164
        if v.ndim != 2: continue
165
        us = unitsize(k)
166
        if us != code_size: continue
167
        perf = v[:, recall_idx]
168
        times = v[:, 3]
169
        bigtab.append(
170
            np.vstack((
171
                np.ones(times.size, dtype=int) * len(names),
172
                perf, times
173
            ))
174
        )
175
        names.append(k)
176
177
    bigtab = np.hstack(bigtab)
178
179
    perm = np.argsort(bigtab[1, :])
180
    bigtab = bigtab[:, perm]
181
182
    times = np.minimum.accumulate(bigtab[2, ::-1])[::-1]
183
    selection = np.where(bigtab[2, :] == times)
184
185
    selected_methods = [names[i] for i in
186
                        np.unique(bigtab[0, selection].astype(int))]
187
    not_selected = list(set(names) - set(selected_methods))
188
189
    print "methods without an optimal OP: ", not_selected
190
191
    nq = 10000
192
    pyplot.title('database ' + db + ' code_size=%d' % code_size)
193
194
    # grayed out lines
195
196
    for k in not_selected:
197
        v = allres[k]
198
        if v.ndim != 2: continue
199
        us = unitsize(k)
200
        if us != code_size: continue
201
202
        linestyle = (':' if 'PQ' in k else
203
                     '-.' if 'SQ4' in k else
204
                     '--' if 'SQ8' in k else '-')
205
206
        pyplot.semilogy(v[:, recall_idx], v[:, 3], label=None,
207
                        linestyle=linestyle,
208
                        marker='o' if 'HNSW' in k else '+',
209
                        color='#cccccc', linewidth=0.2)
210
211
    # important methods
212
    for k in selected_methods:
213
        v = allres[k]
214
        if v.ndim != 2: continue
215
        us = unitsize(k)
216
        if us != code_size: continue
217
218
        stats = allstats[k]
219
        tot_size = stats['index_size'] + stats['tables_size']
220
        id_size = 8 # 64 bit
221
222
        addt = ''
223
        if 'add_time' in stats:
224
            add_time = stats['add_time']
225
            if add_time > 7200:
226
                add_min = add_time / 60
227
                addt = ', %dh%02d' % (add_min / 60, add_min % 60)
228
            else:
229
                add_sec = int(add_time)
230
                addt = ', %dm%02d' % (add_sec / 60, add_sec % 60)
231
232
233
        label = k + ' (size+%.1f%%%s)' % (
234
            tot_size / float((code_size + id_size) * dbsize) * 100 - 100,
235
            addt)
236
237
        linestyle = (':' if 'PQ' in k else
238
                     '-.' if 'SQ4' in k else
239
                     '--' if 'SQ8' in k else '-')
240
241
        pyplot.semilogy(v[:, recall_idx], v[:, 3], label=label,
242
                        linestyle=linestyle,
243
                        marker='o' if 'HNSW' in k else '+')
244
245
    if len(not_selected) == 0:
246
        om = ''
247
    else:
248
        om = '\nomitted:'
249
        nc = len(om)
250
        for m in not_selected:
251
            if nc > 80:
252
                om += '\n'
253
                nc = 0
254
            om += ' ' + m
255
            nc += len(m) + 1
256
257
    pyplot.xlabel('1-recall at %d %s' % (recall_rank, om) )
258
    pyplot.ylabel('search time per query (ms, %d threads)' % n_threads)
259
    pyplot.legend()
260
    pyplot.grid()
261
    pyplot.savefig('figs/tradeoffs_%s_cs%d_r%d.png' % (
262
        db, code_size, recall_rank))
263
    return selected_methods, not_selected
264
265
266
pyplot.gcf().set_size_inches(15, 10)
267
268
plot_tradeoffs(allres, code_size=code_size, recall_rank=1)
269