|
1
|
|
|
# Copyright (c) Facebook, Inc. and its affiliates. |
|
2
|
|
|
# |
|
3
|
|
|
# This source code is licensed under the MIT license found in the |
|
4
|
|
|
# LICENSE file in the root directory of this source tree. |
|
5
|
|
|
|
|
6
|
|
|
#@nolint |
|
7
|
|
|
|
|
8
|
|
|
# not linting this file because it imports * from swigfaiss, which |
|
9
|
|
|
# causes a ton of useless warnings. |
|
10
|
|
|
|
|
11
|
|
|
import numpy as np |
|
12
|
|
|
import sys |
|
13
|
|
|
import inspect |
|
14
|
|
|
import pdb |
|
15
|
|
|
import platform |
|
16
|
|
|
import subprocess |
|
17
|
|
|
import logging |
|
18
|
|
|
|
|
19
|
|
|
|
|
20
|
|
|
logger = logging.getLogger(__name__)


def instruction_set():
    """Detect whether the host CPU supports AVX2.

    Returns:
        "AVX2" when AVX2 support is detected, "default" otherwise. Platforms
        other than macOS and Linux, and any failure while probing the CPU,
        also yield "default".
    """
    system = platform.system()

    if system == "Darwin":
        try:
            out = subprocess.check_output(
                ["/usr/sbin/sysctl", "hw.optional.avx2_0"])
        except (OSError, subprocess.CalledProcessError):
            # sysctl missing or the key unknown: assume no AVX2.
            return "default"
        # check_output returns bytes (typically b"hw.optional.avx2_0: 1\n"),
        # so comparing its last element with the str '1' can never succeed;
        # decode and compare the last whitespace-separated token instead.
        tokens = out.decode("ascii", "replace").split()
        return "AVX2" if tokens[-1:] == ["1"] else "default"

    if system == "Linux":
        # Parse /proc/cpuinfo directly rather than through numpy.distutils,
        # which is deprecated and missing from recent NumPy releases.
        try:
            with open("/proc/cpuinfo") as f:
                for line in f:
                    if line.startswith("flags"):
                        # Only the first processor's flags line is inspected.
                        flags = line.partition(":")[2].split()
                        return "AVX2" if "avx2" in flags else "default"
        except OSError:
            # Best effort: an unreadable cpuinfo just selects the generic build.
            pass
        return "default"

    return "default"
|
35
|
|
|
|
|
36
|
|
|
|
|
37
|
|
|
try:
    instr_set = instruction_set()
    if instr_set == "AVX2":
        logger.info("Loading faiss with AVX2 support.")
        from .swigfaiss_avx2 import *
    else:
        logger.info("Loading faiss.")
        from .swigfaiss import *

except ImportError as import_error:
    # Reached e.g. when the AVX2 extension module is missing or cannot be
    # loaded. Record why instead of discarding the error silently, then fall
    # back to the generic build.
    logger.info("Could not load the selected faiss extension (%s); "
                "falling back to the generic build.", import_error)
    # we import * so that the symbol X can be accessed as faiss.X
    logger.info("Loading faiss.")
    from .swigfaiss import *


# Version string assembled from the constants exported by the SWIG module.
__version__ = "%d.%d.%d" % (FAISS_VERSION_MAJOR,
                            FAISS_VERSION_MINOR,
                            FAISS_VERSION_PATCH)
|
|
|
|
|
|
55
|
|
|
|
|
56
|
|
|
################################################################## |
|
57
|
|
|
# The functions below add or replace some methods for classes |
|
58
|
|
|
# this is to be able to pass in numpy arrays directly |
|
59
|
|
|
# The C++ version of the classnames will be suffixed with _c |
|
60
|
|
|
################################################################## |
|
61
|
|
|
|
|
62
|
|
|
|
|
63
|
|
|
def replace_method(the_class, name, replacement, ignore_missing=False):
    """Install ``replacement`` as ``the_class.<name>``.

    The method being replaced stays reachable as ``the_class.<name>_c``.

    Args:
        the_class: class patched in place.
        name: attribute name of the method to replace.
        replacement: function to install; by convention it is named
            ``replacement_<name>``, which is how an already-patched
            (inherited) method is recognized.
        ignore_missing: when True, do nothing if ``the_class`` has no
            attribute ``name``; otherwise the AttributeError propagates.
    """
    try:
        orig_method = getattr(the_class, name)
    except AttributeError:
        if not ignore_missing:
            raise
        return

    # The attribute found may be a replacement installed on a parent class
    # and inherited here; patching it again would stash the Python wrapper
    # as the "_c" method, so leave it alone.
    already_patched = orig_method.__name__ == 'replacement_' + name
    if not already_patched:
        setattr(the_class, name + '_c', orig_method)
        setattr(the_class, name, replacement)
|
75
|
|
|
|
|
76
|
|
|
|
|
77
|
|
|
def handle_Clustering():
    """Give Clustering numpy-friendly train() / train_encoded() methods.

    The SWIG versions remain available as train_c / train_encoded_c.
    """

    def replacement_train(self, x, index, weights=None):
        # x: (n, d) training vectors; weights: optional (n,) array
        n, d = x.shape
        assert d == self.d
        if weights is None:
            self.train_c(n, swig_ptr(x), index)
            return
        assert weights.shape == (n, )
        self.train_c(n, swig_ptr(x), index, swig_ptr(weights))

    def replacement_train_encoded(self, x, codec, index, weights=None):
        # x: n encoded vectors, each row holding codec.sa_code_size() entries
        n, code_size = x.shape
        assert code_size == codec.sa_code_size()
        assert codec.d == index.d
        if weights is None:
            self.train_encoded_c(n, swig_ptr(x), codec, index)
            return
        assert weights.shape == (n, )
        self.train_encoded_c(n, swig_ptr(x), codec, index, swig_ptr(weights))

    for method_name, impl in (('train', replacement_train),
                              ('train_encoded', replacement_train_encoded)):
        replace_method(Clustering, method_name, impl)


handle_Clustering()
|
100
|
|
|
|
|
101
|
|
|
|
|
102
|
|
|
def handle_Quantizer(the_class):
    """Wrap train / compute_codes / decode of a quantizer class so they take
    and return numpy arrays (the SWIG versions keep a ``_c`` suffix)."""

    def replacement_train(self, x):
        n, d = x.shape
        assert d == self.d
        self.train_c(n, swig_ptr(x))

    def replacement_compute_codes(self, x):
        # one row of self.code_size uint8 entries per input vector
        n, d = x.shape
        assert d == self.d
        out = np.empty((n, self.code_size), dtype='uint8')
        self.compute_codes_c(swig_ptr(x), swig_ptr(out), n)
        return out

    def replacement_decode(self, codes):
        # (n, code_size) codes -> (n, d) float32 vectors
        n, code_size = codes.shape
        assert code_size == self.code_size
        out = np.empty((n, self.d), dtype='float32')
        self.decode_c(swig_ptr(codes), swig_ptr(out), n)
        return out

    replace_method(the_class, 'train', replacement_train)
    replace_method(the_class, 'compute_codes', replacement_compute_codes)
    replace_method(the_class, 'decode', replacement_decode)


handle_Quantizer(ProductQuantizer)
handle_Quantizer(ScalarQuantizer)
|
|
|
|
|
|
130
|
|
|
|
|
131
|
|
|
|
|
132
|
|
|
def handle_Index(the_class):
    """Replace the pointer-based SWIG methods of an Index class with
    numpy-friendly versions.

    The SWIG originals stay reachable under their name plus a ``_c`` suffix.
    The nested functions must keep their ``replacement_<name>`` names:
    replace_method relies on them to skip methods already patched on a
    parent class.
    """

    def replacement_add(self, x):
        """Add the n vectors of the (n, d) array x."""
        assert x.flags.contiguous
        n, d = x.shape
        assert d == self.d
        self.add_c(n, swig_ptr(x))

    def replacement_add_with_ids(self, x, ids):
        """Add the vectors of x, using the 1-D array ids as their ids."""
        n, d = x.shape
        assert d == self.d
        assert ids.shape == (n, ), 'not same nb of vectors as ids'
        self.add_with_ids_c(n, swig_ptr(x), swig_ptr(ids))

    def replacement_assign(self, x, k=1):
        """Return an (n, k) int64 array of labels for the queries in x.

        k defaults to 1 (one label per query); passing it explicitly works
        as before.
        """
        n, d = x.shape
        assert d == self.d
        labels = np.empty((n, k), dtype=np.int64)
        self.assign_c(n, swig_ptr(x), swig_ptr(labels), k)
        return labels

    def replacement_train(self, x):
        """Train the index on the (n, d) array x."""
        assert x.flags.contiguous
        n, d = x.shape
        assert d == self.d
        self.train_c(n, swig_ptr(x))

    def replacement_search(self, x, k):
        """k-NN search; returns (distances float32, labels int64), each (n, k)."""
        n, d = x.shape
        assert d == self.d
        distances = np.empty((n, k), dtype=np.float32)
        labels = np.empty((n, k), dtype=np.int64)
        self.search_c(n, swig_ptr(x),
                      k, swig_ptr(distances),
                      swig_ptr(labels))
        return distances, labels

    def replacement_search_and_reconstruct(self, x, k):
        """Like search, but also returns the (n, k, d) reconstructed
        vectors of the results."""
        n, d = x.shape
        assert d == self.d
        distances = np.empty((n, k), dtype=np.float32)
        labels = np.empty((n, k), dtype=np.int64)
        recons = np.empty((n, k, d), dtype=np.float32)
        self.search_and_reconstruct_c(n, swig_ptr(x),
                                      k, swig_ptr(distances),
                                      swig_ptr(labels),
                                      swig_ptr(recons))
        return distances, labels, recons

    def replacement_remove_ids(self, x):
        """Remove ids given as an IDSelector or as a 1-D array of ids.

        Returns the value of remove_ids_c.
        """
        if isinstance(x, IDSelector):
            sel = x
        else:
            assert x.ndim == 1
            index_ivf = try_extract_index_ivf(self)
            # IVF indexes with a hashtable direct map get an IDSelectorArray;
            # everything else gets an IDSelectorBatch.
            if index_ivf and index_ivf.direct_map.type == DirectMap.Hashtable:
                sel = IDSelectorArray(x.size, swig_ptr(x))
            else:
                sel = IDSelectorBatch(x.size, swig_ptr(x))
        return self.remove_ids_c(sel)

    def replacement_reconstruct(self, key):
        """Reconstruct the vector for ``key`` as a (d,) float32 array."""
        x = np.empty(self.d, dtype=np.float32)
        self.reconstruct_c(key, swig_ptr(x))
        return x

    def replacement_reconstruct_n(self, n0, ni):
        """Reconstruct ni consecutive vectors starting at n0 -> (ni, d)."""
        x = np.empty((ni, self.d), dtype=np.float32)
        self.reconstruct_n_c(n0, ni, swig_ptr(x))
        return x

    def replacement_update_vectors(self, keys, x):
        """Overwrite the vectors of the 1-D array keys with the rows of x."""
        n = keys.size
        assert keys.shape == (n, )
        assert x.shape == (n, self.d)
        self.update_vectors_c(n, swig_ptr(keys), swig_ptr(x))

    def replacement_range_search(self, x, thresh):
        """Range search; returns (lims, D, I).

        lims has n + 1 entries and lims[-1] is the total number of results;
        presumably query i's results are D[lims[i]:lims[i + 1]] (same for I).
        """
        n, d = x.shape
        assert d == self.d
        res = RangeSearchResult(n)
        self.range_search_c(n, swig_ptr(x), thresh, res)
        # The rev_swig_ptr views presumably point into buffers owned by
        # res, which goes away when this function returns: copy them.
        lims = rev_swig_ptr(res.lims, n + 1).copy()
        nd = int(lims[-1])
        D = rev_swig_ptr(res.distances, nd).copy()
        I = rev_swig_ptr(res.labels, nd).copy()
        return lims, D, I

    def replacement_sa_encode(self, x):
        """Encode the (n, d) array x into (n, sa_code_size()) uint8 codes."""
        n, d = x.shape
        assert d == self.d
        codes = np.empty((n, self.sa_code_size()), dtype='uint8')
        self.sa_encode_c(n, swig_ptr(x), swig_ptr(codes))
        return codes

    def replacement_sa_decode(self, codes):
        """Decode (n, sa_code_size()) codes into an (n, d) float32 array."""
        n, cs = codes.shape
        assert cs == self.sa_code_size()
        x = np.empty((n, self.d), dtype='float32')
        self.sa_decode_c(n, swig_ptr(codes), swig_ptr(x))
        return x

    replace_method(the_class, 'add', replacement_add)
    replace_method(the_class, 'add_with_ids', replacement_add_with_ids)
    replace_method(the_class, 'assign', replacement_assign)
    replace_method(the_class, 'train', replacement_train)
    replace_method(the_class, 'search', replacement_search)
    replace_method(the_class, 'remove_ids', replacement_remove_ids)
    replace_method(the_class, 'reconstruct', replacement_reconstruct)
    replace_method(the_class, 'reconstruct_n', replacement_reconstruct_n)
    replace_method(the_class, 'range_search', replacement_range_search)
    replace_method(the_class, 'update_vectors', replacement_update_vectors,
                   ignore_missing=True)
    replace_method(the_class, 'search_and_reconstruct',
                   replacement_search_and_reconstruct, ignore_missing=True)
    replace_method(the_class, 'sa_encode', replacement_sa_encode)
    replace_method(the_class, 'sa_decode', replacement_sa_decode)
|
250
|
|
|
|
|
251
|
|
|
def handle_IndexBinary(the_class):
    """Replace the pointer-based SWIG methods of a binary Index class with
    numpy-friendly versions (the originals keep a ``_c`` suffix).

    Vectors are passed as (n, d // 8) arrays: the index dimension d counts
    bits, packed into d // 8 bytes per row. The nested functions keep their
    ``replacement_<name>`` names, which replace_method checks.
    """

    def replacement_add(self, x):
        assert x.flags.contiguous
        n, nbytes = x.shape
        assert nbytes * 8 == self.d
        self.add_c(n, swig_ptr(x))

    def replacement_add_with_ids(self, x, ids):
        n, nbytes = x.shape
        assert nbytes * 8 == self.d
        assert ids.shape == (n, ), 'not same nb of vectors as ids'
        self.add_with_ids_c(n, swig_ptr(x), swig_ptr(ids))

    def replacement_train(self, x):
        assert x.flags.contiguous
        n, nbytes = x.shape
        assert nbytes * 8 == self.d
        self.train_c(n, swig_ptr(x))

    def replacement_reconstruct(self, key):
        code = np.empty(self.d // 8, dtype=np.uint8)
        self.reconstruct_c(key, swig_ptr(code))
        return code

    def replacement_search(self, x, k):
        # distances come back as int32 for binary indexes
        n, nbytes = x.shape
        assert nbytes * 8 == self.d
        dist = np.empty((n, k), dtype=np.int32)
        ids = np.empty((n, k), dtype=np.int64)
        self.search_c(n, swig_ptr(x), k, swig_ptr(dist), swig_ptr(ids))
        return dist, ids

    def replacement_range_search(self, x, thresh):
        n, nbytes = x.shape
        assert nbytes * 8 == self.d
        result = RangeSearchResult(n)
        self.range_search_c(n, swig_ptr(x), thresh, result)
        # copy out of the result object's buffers before it goes away
        lims = rev_swig_ptr(result.lims, n + 1).copy()
        total = int(lims[-1])
        D = rev_swig_ptr(result.distances, total).copy()
        I = rev_swig_ptr(result.labels, total).copy()
        return lims, D, I

    def replacement_remove_ids(self, x):
        if isinstance(x, IDSelector):
            selector = x
        else:
            assert x.ndim == 1
            selector = IDSelectorBatch(x.size, swig_ptr(x))
        return self.remove_ids_c(selector)

    for method_name, impl in (
            ('add', replacement_add),
            ('add_with_ids', replacement_add_with_ids),
            ('train', replacement_train),
            ('search', replacement_search),
            ('range_search', replacement_range_search),
            ('reconstruct', replacement_reconstruct),
            ('remove_ids', replacement_remove_ids)):
        replace_method(the_class, method_name, impl)
|
313
|
|
|
|
|
314
|
|
|
|
|
315
|
|
|
def handle_VectorTransform(the_class):
    """Give a VectorTransform class numpy-friendly train / apply_py /
    reverse_transform methods (SWIG versions keep a ``_c`` suffix)."""

    def apply_method(self, x):
        """Apply the transform to the (n, d_in) array x -> (n, d_out) float32."""
        assert x.flags.contiguous
        n, d = x.shape
        assert d == self.d_in
        y = np.empty((n, self.d_out), dtype=np.float32)
        self.apply_noalloc(n, swig_ptr(x), swig_ptr(y))
        return y

    def replacement_reverse_transform(self, x):
        """Map the (n, d_out) array x back to an (n, d_in) float32 array."""
        n, d = x.shape
        assert d == self.d_out
        y = np.empty((n, self.d_in), dtype=np.float32)
        self.reverse_transform_c(n, swig_ptr(x), swig_ptr(y))
        return y

    # Must be named replacement_train: replace_method recognizes an
    # inherited, already-patched method by the name 'replacement_' + name.
    # Under any other name, a subclass inheriting the patched train would
    # get the wrapper stored as train_c, and train() would then call the
    # wrapper again with C-style (n, pointer) arguments and fail.
    def replacement_train(self, x):
        """Train the transform on the (n, d_in) array x."""
        assert x.flags.contiguous
        n, d = x.shape
        assert d == self.d_in
        self.train_c(n, swig_ptr(x))

    replace_method(the_class, 'train', replacement_train)
    # 'apply' is treated as reserved (it is a Python 2 builtin), so the
    # numpy-friendly version is exposed as apply_py
    the_class.apply_py = apply_method
    replace_method(the_class, 'reverse_transform',
                   replacement_reverse_transform)
|
343
|
|
|
|
|
344
|
|
|
|
|
345
|
|
|
def handle_AutoTuneCriterion(the_class):
    """Make set_groundtruth / evaluate of an AutoTuneCriterion class accept
    numpy arrays (SWIG versions keep a ``_c`` suffix)."""

    def replacement_set_groundtruth(self, D, I):
        """Set the ground truth from labels I (nq, gt_nnn) and optional
        distances D of the same shape (None to omit)."""
        # Compare against None explicitly: truth-testing a numpy array
        # ("if D:") raises ValueError when it has more than one element.
        if D is not None:
            assert I.shape == D.shape
        self.nq, self.gt_nnn = I.shape
        self.set_groundtruth_c(
            self.gt_nnn, swig_ptr(D) if D is not None else None, swig_ptr(I))

    def replacement_evaluate(self, D, I):
        """Evaluate search results D, I of shape (nq, nnn)."""
        assert I.shape == D.shape
        assert I.shape == (self.nq, self.nnn)
        return self.evaluate_c(swig_ptr(D), swig_ptr(I))

    replace_method(the_class, 'set_groundtruth', replacement_set_groundtruth)
    replace_method(the_class, 'evaluate', replacement_evaluate)
|
360
|
|
|
|
|
361
|
|
|
|
|
362
|
|
|
def handle_ParameterSpace(the_class):
    """Make ParameterSpace.explore accept a numpy query matrix and return
    the OperatingPoints filled by explore_c."""

    def replacement_explore(self, index, xq, crit):
        # one query row per query of the criterion, each of dimension index.d
        assert xq.shape == (crit.nq, index.d)
        points = OperatingPoints()
        self.explore_c(index, crit.nq, swig_ptr(xq), crit, points)
        return points

    replace_method(the_class, 'explore', replacement_explore)
|
370
|
|
|
|
|
371
|
|
|
|
|
372
|
|
|
def handle_MatrixStats(the_class):
    """Let the class be constructed directly from a 2-D numpy matrix m,
    forwarding (rows, columns, pointer) to the SWIG constructor."""
    swig_init = the_class.__init__

    def replacement_init(self, m):
        assert len(m.shape) == 2
        nrow, ncol = m.shape
        swig_init(self, nrow, ncol, swig_ptr(m))

    the_class.__init__ = replacement_init


handle_MatrixStats(MatrixStats)
|
|
|
|
|
|
382
|
|
|
|
|
383
|
|
|
|
|
384
|
|
|
this_module = sys.modules[__name__]


# Patch every class exported into this module with the numpy-friendly
# wrappers matching its base class(es). dir() lists names alphabetically;
# a class deriving from several handled bases gets every applicable handler,
# in the order below.
for symbol in dir(this_module):
    obj = getattr(this_module, symbol)
    if not inspect.isclass(obj):
        continue
    the_class = obj
    if issubclass(the_class, Index):
        handle_Index(the_class)
    if issubclass(the_class, IndexBinary):
        handle_IndexBinary(the_class)
    if issubclass(the_class, VectorTransform):
        handle_VectorTransform(the_class)
    if issubclass(the_class, AutoTuneCriterion):
        handle_AutoTuneCriterion(the_class)
    if issubclass(the_class, ParameterSpace):
        handle_ParameterSpace(the_class)
|
406
|
|
|
|
|
407
|
|
|
|
|
408
|
|
|
########################################### |
|
409
|
|
|
# Add Python references to objects |
|
410
|
|
|
# we do this at the Python class wrapper level. |
|
411
|
|
|
########################################### |
|
412
|
|
|
|
|
413
|
|
|
def add_ref_in_constructor(the_class, parameter_no):
    """Wrap the_class.__init__ so the object keeps Python references to some
    constructor arguments, so they do not get deallocated before it.

    Args:
        the_class: class whose __init__ is replaced.
        parameter_no: index of the positional argument to keep, or a dict
            mapping a positional-argument count to the list of argument
            indices to keep for that count.

    The kept arguments are stored in ``self.referenced_objects``. Calls with
    too few arguments, or with an argument count missing from the dict
    (e.g. a default constructor), keep nothing instead of raising after the
    object was constructed.
    """
    original_init = the_class.__init__

    def replacement_init(self, *args):
        original_init(self, *args)
        if len(args) > parameter_no:
            self.referenced_objects = [args[parameter_no]]
        else:
            self.referenced_objects = []

    def replacement_init_multiple(self, *args):
        original_init(self, *args)
        pset = parameter_no.get(len(args), ())
        self.referenced_objects = [args[no] for no in pset]

    if isinstance(parameter_no, dict):
        # a list of parameters to keep, depending on the number of arguments
        the_class.__init__ = replacement_init_multiple
    else:
        the_class.__init__ = replacement_init
|
432
|
|
|
|
|
433
|
|
|
def add_ref_in_method(the_class, method_name, parameter_no):
    """Wrap ``the_class.<method_name>`` so every call appends its positional
    argument number ``parameter_no`` to ``self.referenced_objects``, keeping
    that argument alive as long as the object.
    """
    wrapped = getattr(the_class, method_name)

    def replacement_method(self, *args):
        keep_alive = args[parameter_no]
        if hasattr(self, 'referenced_objects'):
            self.referenced_objects.append(keep_alive)
        else:
            self.referenced_objects = [keep_alive]
        return wrapped(self, *args)

    setattr(the_class, method_name, replacement_method)
|
443
|
|
|
|
|
444
|
|
|
def add_ref_in_function(function_name, parameter_no):
    """Replace the module-level function ``function_name`` by a wrapper that
    stores its positional argument ``parameter_no`` on the returned object
    (as ``referenced_objects``), keeping it alive as long as the result.

    Assumes the function returns an object that accepts new attributes.
    """
    wrapped = getattr(this_module, function_name)

    def replacement_function(*args):
        result = wrapped(*args)
        result.referenced_objects = [args[parameter_no]]
        return result

    setattr(this_module, function_name, replacement_function)
|
453
|
|
|
|
|
454
|
|
|
# Keep Python references to the arguments below on the objects built from
# (or modified with) them, so they do not get deallocated first. Constructor
# and method patches touch different attributes; each list keeps its order.
for _ref_class, _ref_params in (
        (IndexIVFFlat, 0),
        (IndexIVFFlatDedup, 0),
        (IndexPreTransform, {2: [0, 1], 1: [0]}),
        (IndexIVFPQ, 0),
        (IndexIVFPQR, 0),
        (Index2Layer, 0),
        (Level1Quantizer, 0),
        (IndexIVFScalarQuantizer, 0),
        (IndexIDMap, 0),
        (IndexIDMap2, 0),
        (IndexHNSW, 0),
        (IndexRefineFlat, 0),
        (IndexBinaryIVF, 0),
        (IndexBinaryFromFloat, 0),
        (IndexBinaryIDMap, 0),
        (IndexBinaryIDMap2, 0),
        (BufferedIOWriter, 0),
        (BufferedIOReader, 0)):
    add_ref_in_constructor(_ref_class, _ref_params)

for _ref_class, _ref_method in (
        (IndexPreTransform, 'prepend_transform'),
        (IndexShards, 'add_shard'),
        (IndexBinaryShards, 'add_shard'),
        (IndexReplicas, 'addIndex'),
        (IndexBinaryReplicas, 'addIndex')):
    add_ref_in_method(_ref_class, _ref_method, 0)

# loop variables are not part of the module's namespace
del _ref_class, _ref_params, _ref_method
|
|
|
|
|
|
479
|
|
|
|
|
480
|
|
|
# seems really marginal... |
|
481
|
|
|
# remove_ref_from_method(IndexReplicas, 'removeIndex', 0) |
|
482
|
|
|
|
|
483
|
|
|
if hasattr(this_module, 'GpuIndexFlat'):
    # GPU build: keep the GPU resources objects (first argument) alive as
    # long as the indexes that use them
    add_ref_in_function('index_cpu_to_gpu', 0)
    for _gpu_class in (GpuIndexFlat, GpuIndexFlatIP, GpuIndexFlatL2,
                       GpuIndexIVFFlat, GpuIndexIVFScalarQuantizer,
                       GpuIndexIVFPQ, GpuIndexBinaryFlat):
        add_ref_in_constructor(_gpu_class, 0)
    del _gpu_class
|
|
|
|
|
|
493
|
|
|
|
|
494
|
|
|
|
|
495
|
|
|
|
|
496
|
|
|
########################################### |
|
497
|
|
|
# GPU functions |
|
498
|
|
|
########################################### |
|
499
|
|
|
|
|
500
|
|
|
|
|
501
|
|
|
def index_cpu_to_gpu_multiple_py(resources, index, co=None, gpus=None):
    """Clone ``index`` onto several GPUs.

    Builds the C++ vectors of GPU resources and device ids expected by
    index_cpu_to_gpu_multiple. ``resources[i]`` is paired with device
    ``gpus[i]``; gpus defaults to 0 .. len(resources) - 1. The returned
    index keeps a reference to ``resources`` so they outlive it.
    """
    if gpus is None:
        gpus = range(len(resources))
    vres = GpuResourcesVector()
    vdev = IntVector()
    for device, res in zip(gpus, resources):
        vdev.push_back(device)
        vres.push_back(res)
    gpu_index = index_cpu_to_gpu_multiple(vres, vdev, index, co)
    gpu_index.referenced_objects = resources
    return gpu_index
|
515
|
|
|
|
|
516
|
|
|
|
|
517
|
|
|
def index_cpu_to_all_gpus(index, co=None, ngpu=-1):
    """Clone ``index`` onto the first ``ngpu`` GPUs (all GPUs when -1)."""
    return index_cpu_to_gpus_list(index, co=co, gpus=None, ngpu=ngpu)
|
520
|
|
|
|
|
521
|
|
|
|
|
522
|
|
|
def index_cpu_to_gpus_list(index, co=None, gpus=None, ngpu=-1):
    """Clone ``index`` onto a set of GPUs.

    gpus: list of GPU ids to use, or None. When None, the first ``ngpu``
    GPUs are used, or every GPU when ngpu is -1. One StandardGpuResources
    object is created per GPU.
    """
    if gpus is None:
        gpus = range(get_num_gpus() if ngpu == -1 else ngpu)
    res = [StandardGpuResources() for _ in gpus]
    return index_cpu_to_gpu_multiple_py(res, index, co, gpus)
|
532
|
|
|
|
|
533
|
|
|
|
|
534
|
|
|
########################################### |
|
535
|
|
|
# numpy array / std::vector conversions |
|
536
|
|
|
########################################### |
|
537
|
|
|
|
|
538
|
|
|
# SWIG vector class name prefix (e.g. "Float" for FloatVector, as declared
# in swigfaiss.swig) -> numpy dtype name of its elements
vector_name_map = dict(
    Float='float32',
    Byte='uint8',
    Char='int8',
    Uint64='uint64',
    Long='int64',
    Int='int32',
    Double='float64',
)
|
548
|
|
|
|
|
549
|
|
|
def vector_to_array(v):
    """Copy the C++ vector ``v`` into a new 1-D numpy array.

    The element dtype is derived from the SWIG class name (e.g. FloatVector
    -> float32) through vector_name_map.
    """
    classname = v.__class__.__name__
    assert classname.endswith('Vector')
    prefix = classname[:-len('Vector')]
    a = np.empty(v.size(), dtype=np.dtype(vector_name_map[prefix]))
    if v.size() > 0:
        memcpy(swig_ptr(a), v.data(), a.nbytes)
    return a
|
558
|
|
|
|
|
559
|
|
|
|
|
560
|
|
|
def vector_float_to_array(v):
    """Convert a C++ float vector to a numpy array.

    Kept for backward compatibility: vector_to_array handles every
    supported element type.
    """
    return vector_to_array(v)
|
562
|
|
|
|
|
563
|
|
|
|
|
564
|
|
|
def copy_array_to_vector(a, v):
    """Copy the 1-D numpy array ``a`` into the C++ vector ``v``.

    ``v`` is resized to len(a). Its element type, derived from its SWIG
    class name through vector_name_map, must equal a.dtype.
    """
    n, = a.shape
    classname = v.__class__.__name__
    assert classname.endswith('Vector')
    expected = np.dtype(vector_name_map[classname[:-len('Vector')]])
    assert expected == a.dtype, (
        'cannot copy a %s array to a %s (should be %s)' % (
            a.dtype, classname, expected))
    v.resize(n)
    if n > 0:
        memcpy(v.data(), swig_ptr(a), a.nbytes)
|
|
|
|
|
|
576
|
|
|
|
|
577
|
|
|
|
|
578
|
|
|
########################################### |
|
579
|
|
|
# Wrapper for a few functions |
|
580
|
|
|
########################################### |
|
581
|
|
|
|
|
582
|
|
def kmin(array, k):
    """Return the k smallest values of each row of a float32 array and
    their column indices, as (D, I) arrays of shape (rows, k)."""
    nrow, ncol = array.shape
    I = np.zeros((nrow, k), dtype='int64')
    D = np.zeros((nrow, k), dtype='float32')
    # a max-heap of size k per row retains the k smallest values seen
    heaps = float_maxheap_array_t()
    heaps.ids = swig_ptr(I)
    heaps.val = swig_ptr(D)
    heaps.nh = nrow
    heaps.k = k
    heaps.heapify()
    heaps.addn(ncol, swig_ptr(array))
    heaps.reorder()
    return D, I
|
597
|
|
|
|
|
598
|
|
|
|
|
599
|
|
def kmax(array, k):
    """Return the k largest values of each row of a float32 array and
    their column indices, as (D, I) arrays of shape (rows, k)."""
    nrow, ncol = array.shape
    I = np.zeros((nrow, k), dtype='int64')
    D = np.zeros((nrow, k), dtype='float32')
    # a min-heap of size k per row retains the k largest values seen
    heaps = float_minheap_array_t()
    heaps.ids = swig_ptr(I)
    heaps.val = swig_ptr(D)
    heaps.nh = nrow
    heaps.k = k
    heaps.heapify()
    heaps.addn(ncol, swig_ptr(array))
    heaps.reorder()
    return D, I
|
614
|
|
|
|
|
615
|
|
|
|
|
616
|
|
|
def pairwise_distances(xq, xb, mt=METRIC_L2, metric_arg=0):
    """Compute the full (nq, nb) float32 distance matrix between the rows
    of xq (nq, d) and xb (nb, d).

    METRIC_L2 is computed with pairwise_L2sqr (squared L2). Any other metric
    goes through pairwise_extra_distances, with metric_arg forwarded.
    """
    nq, d = xq.shape
    nb, d2 = xb.shape
    assert d == d2
    dis = np.empty((nq, nb), dtype='float32')
    if mt != METRIC_L2:
        pairwise_extra_distances(d, nq, swig_ptr(xq), nb, swig_ptr(xb),
                                 mt, metric_arg, swig_ptr(dis))
    else:
        pairwise_L2sqr(d, nq, swig_ptr(xq), nb, swig_ptr(xb), swig_ptr(dis))
    return dis
|
635
|
|
|
|
|
636
|
|
|
|
|
637
|
|
|
|
|
638
|
|
|
|
|
639
|
|
|
def rand(n, seed=12345):
    """Return a float32 array of shape n filled by faiss' float_rand,
    deterministic for a given seed."""
    out = np.empty(n, dtype='float32')
    float_rand(swig_ptr(out), out.size, seed)
    return out
|
643
|
|
|
|
|
644
|
|
|
|
|
645
|
|
|
def randint(n, seed=12345, vmax=None):
    """Return an int64 array of shape n filled by faiss' RNG.

    Without vmax, int64_rand is used; with vmax, int64_rand_max
    (presumably bounding the values by vmax).
    """
    out = np.empty(n, dtype='int64')
    if vmax is not None:
        int64_rand_max(swig_ptr(out), out.size, vmax, seed)
    else:
        int64_rand(swig_ptr(out), out.size, seed)
    return out


lrand = randint
|
654
|
|
|
|
|
655
|
|
|
def randn(n, seed=12345):
    """Return a float32 array of shape n filled by faiss' float_randn,
    deterministic for a given seed."""
    out = np.empty(n, dtype='float32')
    float_randn(swig_ptr(out), out.size, seed)
    return out
|
659
|
|
|
|
|
660
|
|
|
|
|
661
|
|
|
def eval_intersection(I1, I2):
    """Total intersection size between corresponding rows of two result
    tables I1 (n, k1) and I2 (n, k2).

    Returns the sum over rows i of ranklist_intersection_size on I1[i], I2[i].
    """
    n = I1.shape[0]
    assert I2.shape[0] == n
    k1 = I1.shape[1]
    k2 = I2.shape[1]
    return sum(
        ranklist_intersection_size(k1, swig_ptr(I1[i]), k2, swig_ptr(I2[i]))
        for i in range(n))
|
671
|
|
|
|
|
672
|
|
|
|
|
673
|
|
|
def normalize_L2(x):
    """Rescale the rows of the 2-D array x to unit L2 norm, in place.

    fvec_renorm_L2 operates on float vectors, so x presumably must be a
    float32 matrix.
    """
    nrows, dim = x.shape[0], x.shape[1]
    fvec_renorm_L2(dim, nrows, swig_ptr(x))
|
|
|
|
|
|
675
|
|
|
|
|
676
|
|
|
# MapLong2Long interface |
|
677
|
|
|
|
|
678
|
|
|
def replacement_map_add(self, keys, vals):
    """MapLong2Long.add: insert the pairs (keys[i], vals[i]).

    keys and vals must be 1-D arrays of the same length.
    """
    n, = keys.shape
    # Check vals against keys (comparing keys.shape with itself never fails).
    assert vals.shape == (n,), 'keys and vals must have the same length'
    self.add_c(n, swig_ptr(keys), swig_ptr(vals))
|
|
|
|
|
|
682
|
|
|
|
|
683
|
|
|
def replacement_map_search_multiple(self, keys):
    """MapLong2Long.search_multiple: look up each entry of the 1-D array
    keys and return the results as an int64 array of the same length."""
    n, = keys.shape
    out = np.empty(n, dtype='int64')
    self.search_multiple_c(n, swig_ptr(keys), swig_ptr(out))
    return out


replace_method(MapLong2Long, 'add', replacement_map_add)
replace_method(MapLong2Long, 'search_multiple', replacement_map_search_multiple)
|
691
|
|
|
|
|
692
|
|
|
|
|
693
|
|
|
########################################### |
|
694
|
|
|
# Kmeans object |
|
695
|
|
|
########################################### |
|
696
|
|
|
|
|
697
|
|
|
|
|
698
|
|
|
class Kmeans:
    """shallow wrapper around the Clustering object. The important method
    is train()."""

    def __init__(self, d, k, **kwargs):
        """d: input dimension, k: nb of centroids. Additional
        parameters are passed on the ClusteringParameters object,
        including niter=25, verbose=False, spherical = False

        The extra keyword ``gpu`` is handled here: when set, the index used
        during training is moved to GPU with index_cpu_to_all_gpus.
        True means all GPUs; an int gives the number of GPUs.
        """
        self.d = d
        self.k = k
        self.gpu = False
        self.cp = ClusteringParameters()
        # loop names chosen so they do not shadow the parameter k
        for name, value in kwargs.items():
            if name == 'gpu':
                self.gpu = value
            else:
                # if this raises an exception, it means that it is a
                # non-existent field
                getattr(self.cp, name)
                setattr(self.cp, name, value)
        self.centroids = None

    def train(self, x, weights=None):
        """Run k-means on the (n, d) array x with optional (n,) weights.

        Sets self.centroids (k, d), self.index and self.obj (objective per
        iteration). Returns the last objective value, or 0.0 when there are
        no iteration statistics.
        """
        _, d = x.shape
        assert d == self.d
        clus = Clustering(d, self.k, self.cp)
        if self.cp.spherical:
            self.index = IndexFlatIP(d)
        else:
            self.index = IndexFlatL2(d)
        if self.gpu:
            # Test identity with True: "1 == True" holds, so an equality test
            # turned gpu=1 (one GPU) into "all GPUs".
            ngpu = -1 if self.gpu is True else self.gpu
            self.index = index_cpu_to_all_gpus(self.index, ngpu=ngpu)
        clus.train(x, self.index, weights)
        centroids = vector_float_to_array(clus.centroids)
        self.centroids = centroids.reshape(self.k, d)
        stats = clus.iteration_stats
        self.obj = np.array([
            stats.at(i).obj for i in range(stats.size())
        ])
        return self.obj[-1] if self.obj.size > 0 else 0.0

    def assign(self, x):
        """Return (distances, labels) of the nearest centroid for each row
        of x, as 1-D arrays."""
        assert self.centroids is not None, "should train before assigning"
        self.index.reset()
        self.index.add(self.centroids)
        D, I = self.index.search(x, 1)
        return D.ravel(), I.ravel()
|
749
|
|
|
|
|
750
|
|
|
# Backward-compatible aliases for names older code may still use:
# ConcatenatedInvertedLists presumably predates HStackInvertedLists, and
# IndexProxy was renamed to IndexReplicas.
ConcatenatedInvertedLists = HStackInvertedLists
IndexProxy = IndexReplicas
|
|
|
|
|
|
754
|
|
|
|
|
755
|
|
|
########################################### |
|
756
|
|
|
# serialization of indexes to byte arrays |
|
757
|
|
|
########################################### |
|
758
|
|
|
|
|
759
|
|
|
def serialize_index(index):
    """Serialize ``index`` into a 1-D numpy uint8 array (the inverse of
    deserialize_index)."""
    sink = VectorIOWriter()
    write_index(index, sink)
    return vector_to_array(sink.data)
|
764
|
|
|
|
|
765
|
|
|
def deserialize_index(data):
    """Rebuild an index from an array produced by serialize_index."""
    source = VectorIOReader()
    copy_array_to_vector(data, source.data)
    return read_index(source)
|
|
|
|
|
|
769
|
|
|
|
|
770
|
|
|
def serialize_index_binary(index):
    """Serialize a binary index into a 1-D numpy uint8 array (the inverse of
    deserialize_index_binary)."""
    sink = VectorIOWriter()
    write_index_binary(index, sink)
    return vector_to_array(sink.data)
|
775
|
|
|
|
|
776
|
|
|
def deserialize_index_binary(data):
    """Rebuild a binary index from an array produced by
    serialize_index_binary."""
    source = VectorIOReader()
    copy_array_to_vector(data, source.data)
    return read_index_binary(source)
|
|
|
|
|
|
780
|
|
|
|
|
781
|
|
|
|
|
782
|
|
|
########################################### |
|
783
|
|
|
# ResultHeap |
|
784
|
|
|
########################################### |
|
785
|
|
|
|
|
786
|
|
class ResultHeap:
    """Accumulate query results from a sliced dataset. The final result will
    be in self.D, self.I."""

    def __init__(self, nq, k):
        " nq: number of query vectors, k: number of results per query "
        self.I = np.zeros((nq, k), dtype='int64')
        self.D = np.zeros((nq, k), dtype='float32')
        self.nq, self.k = nq, k
        # one max-heap of size k per query retains the k smallest distances
        heaps = float_maxheap_array_t()
        heaps.k = k
        heaps.nh = nq
        heaps.val = swig_ptr(self.D)
        heaps.ids = swig_ptr(self.I)
        heaps.heapify()
        self.heaps = heaps

    def add_result(self, D, I):
        """Merge one batch of results into the heaps.

        D, I do not need to be in a particular order (heap or sorted).
        Both have shape (nq, kd); kd may differ from self.k (for instance a
        slice that returned fewer results per query).
        """
        nq, kd = D.shape
        assert nq == self.nq
        assert I.shape == (nq, kd)
        # Match the dtypes of self.D / self.I, which back the C++ heaps, and
        # make the buffers contiguous before taking raw pointers.
        D = np.ascontiguousarray(D, dtype='float32')
        I = np.ascontiguousarray(I, dtype='int64')
        # Use the module-level swig_ptr: the name `faiss` is not bound
        # inside this module, so faiss.swig_ptr raised NameError.
        self.heaps.addn_with_ids(kd, swig_ptr(D), swig_ptr(I), kd)

    def finalize(self):
        """Put the accumulated results in their final (presumably sorted)
        order; call once after the last add_result."""
        self.heaps.reorder()
|
813
|
|
|
|