1
|
|
|
# Copyright (c) Facebook, Inc. and its affiliates. |
2
|
|
|
# |
3
|
|
|
# This source code is licensed under the MIT license found in the |
4
|
|
|
# LICENSE file in the root directory of this source tree. |
5
|
|
|
|
6
|
|
|
#@nolint |
7
|
|
|
|
8
|
|
|
# not linting this file because it imports * from swigfaiss, which |
9
|
|
|
# causes a ton of useless warnings. |
10
|
|
|
|
11
|
|
|
import numpy as np |
12
|
|
|
import sys |
13
|
|
|
import inspect |
14
|
|
|
import pdb |
15
|
|
|
import platform |
16
|
|
|
import subprocess |
17
|
|
|
import logging |
18
|
|
|
|
19
|
|
|
|
20
|
|
|
# Module-level logger; the library-loading code below reports which SWIG
# extension it picked through it.
logger = logging.getLogger(name=__name__)
21
|
|
|
|
22
|
|
|
|
23
|
|
|
def _parse_sysctl_flag(output):
    """Return True when the output of ``sysctl <key>`` reads as enabled.

    ``output`` may be bytes (as returned by subprocess.check_output) or str,
    in either the ``"key: 1"`` form or the value-only ``"1"`` form; only the
    text after the last colon is inspected.
    """
    if isinstance(output, bytes):
        output = output.decode("ascii", errors="replace")
    return output.rsplit(":", 1)[-1].strip() == "1"


def _parse_cpuinfo_flags(text):
    """Return the set of feature flags from the first ``flags`` line of
    /proc/cpuinfo-formatted ``text`` (empty set if there is none)."""
    for line in text.splitlines():
        key, sep, value = line.partition(":")
        if sep and key.strip() == "flags":
            return set(value.split())
    return set()


def instruction_set():
    """Detect which SIMD flavour of the SWIG extension should be loaded.

    Returns:
        "AVX2" when the CPU advertises AVX2, "default" otherwise -- including
        on platforms other than macOS/Linux and whenever detection fails, so
        callers always get one of these two strings.
    """
    system = platform.system()
    if system == "Darwin":
        try:
            output = subprocess.check_output(
                ["/usr/sbin/sysctl", "hw.optional.avx2_0"],
                stderr=subprocess.DEVNULL)
        except (OSError, subprocess.CalledProcessError):
            # sysctl unavailable, or the key does not exist on this machine
            return "default"
        # check_output returns bytes, terminated by a newline
        return "AVX2" if _parse_sysctl_flag(output) else "default"
    if system == "Linux":
        # read the flags directly instead of going through numpy.distutils,
        # which is deprecated and missing from recent numpy/Python releases
        try:
            with open("/proc/cpuinfo", encoding="utf-8",
                      errors="replace") as f:
                flags = _parse_cpuinfo_flags(f.read())
        except OSError:
            return "default"
        return "AVX2" if "avx2" in flags else "default"
    return "default"
35
|
|
|
|
36
|
|
|
|
37
|
|
|
# Pick the SWIG extension to load. CPU feature detection is best effort: a
# failure there must never prevent faiss from loading, so any error falls
# back to the generic build.
try:
    instr_set = instruction_set()
except Exception:
    logger.warning("Could not detect the CPU instruction set, "
                   "loading the generic faiss build.", exc_info=True)
    instr_set = "default"

# we import * so that the symbol X can be accessed as faiss.X
if instr_set == "AVX2":
    try:
        logger.info("Loading faiss with AVX2 support.")
        from .swigfaiss_avx2 import *
    except ImportError as avx2_error:
        logger.info("Could not load the AVX2 build (%s).", avx2_error)
        logger.info("Loading faiss.")
        from .swigfaiss import *
else:
    logger.info("Loading faiss.")
    from .swigfaiss import *
50
|
|
|
|
51
|
|
|
|
52
|
|
|
# "major.minor.patch", built from the integer constants exported by SWIG.
__version__ = ".".join(
    "%d" % part
    for part in (FAISS_VERSION_MAJOR, FAISS_VERSION_MINOR,
                 FAISS_VERSION_PATCH))
|
|
|
|
55
|
|
|
|
56
|
|
|
################################################################## |
57
|
|
|
# The functions below add or replace some methods for classes |
58
|
|
|
# this is to be able to pass in numpy arrays directly |
59
|
|
|
# The original C++ version of each replaced method is kept under its name suffixed with _c |
60
|
|
|
################################################################## |
61
|
|
|
|
62
|
|
|
|
63
|
|
|
def replace_method(the_class, name, replacement, ignore_missing=False):
    """Replace method ``name`` of ``the_class`` by ``replacement``.

    The original method remains reachable as ``name + '_c'`` so that the
    replacement can delegate to it.

    If the attribute found on the class is already a replacement (inherited
    from a parent class patched earlier), nothing is done: wrapping it again
    would make ``<name>_c`` point to the Python wrapper instead of the
    original method.

    Args:
        the_class: class to patch.
        name: name of the method to replace.
        replacement: function installed as the new method.
        ignore_missing: if True, classes without ``name`` are skipped.

    Raises:
        AttributeError: ``name`` does not exist and ``ignore_missing`` is
            False.
    """
    try:
        orig_method = getattr(the_class, name)
    except AttributeError:
        if ignore_missing:
            return
        raise
    orig_name = getattr(orig_method, '__name__', None)
    # Recognise the conventional 'replacement_<name>' naming as well as
    # replacements named differently (e.g. 'replacement_vt_train').
    if orig_name in ('replacement_' + name, replacement.__name__):
        # replacement was done in parent class
        return
    setattr(the_class, name + '_c', orig_method)
    setattr(the_class, name, replacement)
75
|
|
|
|
76
|
|
|
|
77
|
|
|
def handle_Clustering():
    """Give Clustering numpy-friendly ``train`` / ``train_encoded`` methods.

    Both accept an (n, d) matrix plus an optional weights vector of length n.
    The original methods stay available as ``train_c`` / ``train_encoded_c``.
    """

    def replacement_train(self, x, index, weights=None):
        n, d = x.shape
        assert d == self.d
        if weights is None:
            self.train_c(n, swig_ptr(x), index)
            return
        assert weights.shape == (n, )
        self.train_c(n, swig_ptr(x), index, swig_ptr(weights))

    def replacement_train_encoded(self, x, codec, index, weights=None):
        # x holds encoded vectors: one row of codec.sa_code_size() per vector
        n, d = x.shape
        assert d == codec.sa_code_size()
        assert codec.d == index.d
        if weights is None:
            self.train_encoded_c(n, swig_ptr(x), codec, index)
            return
        assert weights.shape == (n, )
        self.train_encoded_c(n, swig_ptr(x), codec, index, swig_ptr(weights))

    for method_name, replacement in (
            ('train', replacement_train),
            ('train_encoded', replacement_train_encoded)):
        replace_method(Clustering, method_name, replacement)


handle_Clustering()
100
|
|
|
|
101
|
|
|
|
102
|
|
|
def handle_Quantizer(the_class):
    """Patch ``train``, ``compute_codes`` and ``decode`` of a quantizer class
    so they take and return numpy arrays."""

    def replacement_train(self, x):
        nx, dim = x.shape
        assert dim == self.d
        self.train_c(nx, swig_ptr(x))

    def replacement_compute_codes(self, x):
        # returns one row of self.code_size bytes per input vector
        nx, dim = x.shape
        assert dim == self.d
        out = np.empty((nx, self.code_size), dtype='uint8')
        self.compute_codes_c(swig_ptr(x), swig_ptr(out), nx)
        return out

    def replacement_decode(self, codes):
        # inverse direction: (n, code_size) uint8 -> (n, d) float32
        ncodes, code_size = codes.shape
        assert code_size == self.code_size
        out = np.empty((ncodes, self.d), dtype='float32')
        self.decode_c(swig_ptr(codes), swig_ptr(out), ncodes)
        return out

    for method_name, replacement in (
            ('train', replacement_train),
            ('compute_codes', replacement_compute_codes),
            ('decode', replacement_decode)):
        replace_method(the_class, method_name, replacement)


handle_Quantizer(ProductQuantizer)
handle_Quantizer(ScalarQuantizer)
|
|
|
|
130
|
|
|
|
131
|
|
|
|
132
|
|
|
def handle_Index(the_class):
    """Install numpy-friendly wrappers on an Index class.

    Each wrapper checks the input shapes against ``self.d``, allocates the
    output arrays and forwards buffer pointers (swig_ptr) to the original
    method, which replace_method keeps reachable as ``<name>_c``.
    The inner functions keep their ``replacement_<name>`` names because
    replace_method uses them to detect methods already patched in a parent.
    """

    def replacement_add(self, x):
        assert x.flags.contiguous
        nx, dim = x.shape
        assert dim == self.d
        self.add_c(nx, swig_ptr(x))

    def replacement_add_with_ids(self, x, ids):
        nx, dim = x.shape
        assert dim == self.d
        assert ids.shape == (nx, ), 'not same nb of vectors as ids'
        self.add_with_ids_c(nx, swig_ptr(x), swig_ptr(ids))

    def replacement_assign(self, x, k):
        nx, dim = x.shape
        assert dim == self.d
        labels = np.empty((nx, k), dtype=np.int64)
        self.assign_c(nx, swig_ptr(x), swig_ptr(labels), k)
        return labels

    def replacement_train(self, x):
        assert x.flags.contiguous
        nx, dim = x.shape
        assert dim == self.d
        self.train_c(nx, swig_ptr(x))

    def replacement_search(self, x, k):
        # returns (distances, labels), both of shape (n, k)
        nx, dim = x.shape
        assert dim == self.d
        distances = np.empty((nx, k), dtype=np.float32)
        labels = np.empty((nx, k), dtype=np.int64)
        self.search_c(nx, swig_ptr(x), k,
                      swig_ptr(distances), swig_ptr(labels))
        return distances, labels

    def replacement_search_and_reconstruct(self, x, k):
        # same as search, plus an (n, k, d) array of reconstructed vectors
        nx, dim = x.shape
        assert dim == self.d
        distances = np.empty((nx, k), dtype=np.float32)
        labels = np.empty((nx, k), dtype=np.int64)
        recons = np.empty((nx, k, dim), dtype=np.float32)
        self.search_and_reconstruct_c(nx, swig_ptr(x), k,
                                      swig_ptr(distances), swig_ptr(labels),
                                      swig_ptr(recons))
        return distances, labels, recons

    def replacement_remove_ids(self, x):
        # x is either a ready-made IDSelector or a 1-D array of ids
        if isinstance(x, IDSelector):
            sel = x
        else:
            assert x.ndim == 1
            index_ivf = try_extract_index_ivf(self)
            # IVF indexes with a hashtable direct map get an IDSelectorArray,
            # everything else an IDSelectorBatch
            selector_class = (
                IDSelectorArray
                if index_ivf and
                index_ivf.direct_map.type == DirectMap.Hashtable
                else IDSelectorBatch)
            sel = selector_class(x.size, swig_ptr(x))
        return self.remove_ids_c(sel)

    def replacement_reconstruct(self, key):
        vec = np.empty(self.d, dtype=np.float32)
        self.reconstruct_c(key, swig_ptr(vec))
        return vec

    def replacement_reconstruct_n(self, n0, ni):
        vecs = np.empty((ni, self.d), dtype=np.float32)
        self.reconstruct_n_c(n0, ni, swig_ptr(vecs))
        return vecs

    def replacement_update_vectors(self, keys, x):
        n = keys.size
        assert keys.shape == (n, )
        assert x.shape == (n, self.d)
        self.update_vectors_c(n, swig_ptr(keys), swig_ptr(x))

    def replacement_range_search(self, x, thresh):
        # returns (lims, D, I): results of query i are D[lims[i]:lims[i + 1]]
        nx, dim = x.shape
        assert dim == self.d
        res = RangeSearchResult(nx)
        self.range_search_c(nx, swig_ptr(x), thresh, res)
        # rev_swig_ptr presumably gives views on res's buffers: copy them
        lims = rev_swig_ptr(res.lims, nx + 1).copy()
        nres = int(lims[-1])
        D = rev_swig_ptr(res.distances, nres).copy()
        I = rev_swig_ptr(res.labels, nres).copy()
        return lims, D, I

    def replacement_sa_encode(self, x):
        nx, dim = x.shape
        assert dim == self.d
        codes = np.empty((nx, self.sa_code_size()), dtype='uint8')
        self.sa_encode_c(nx, swig_ptr(x), swig_ptr(codes))
        return codes

    def replacement_sa_decode(self, codes):
        ncodes, code_size = codes.shape
        assert code_size == self.sa_code_size()
        x = np.empty((ncodes, self.d), dtype='float32')
        self.sa_decode_c(ncodes, swig_ptr(codes), swig_ptr(x))
        return x

    # (method name, replacement, skip classes lacking the method)
    patches = (
        ('add', replacement_add, False),
        ('add_with_ids', replacement_add_with_ids, False),
        ('assign', replacement_assign, False),
        ('train', replacement_train, False),
        ('search', replacement_search, False),
        ('remove_ids', replacement_remove_ids, False),
        ('reconstruct', replacement_reconstruct, False),
        ('reconstruct_n', replacement_reconstruct_n, False),
        ('range_search', replacement_range_search, False),
        ('update_vectors', replacement_update_vectors, True),
        ('search_and_reconstruct', replacement_search_and_reconstruct, True),
        ('sa_encode', replacement_sa_encode, False),
        ('sa_decode', replacement_sa_decode, False),
    )
    for method_name, replacement, optional in patches:
        replace_method(the_class, method_name, replacement,
                       ignore_missing=optional)
250
|
|
|
|
251
|
|
|
def handle_IndexBinary(the_class):
    """Install numpy-friendly wrappers on an IndexBinary class.

    Vectors are rows of bytes: ``self.d`` counts bits, so an input row has
    ``self.d // 8`` entries (hence the ``* 8`` in the dimension checks).
    """

    def replacement_add(self, x):
        assert x.flags.contiguous
        nx, nbytes = x.shape
        assert nbytes * 8 == self.d
        self.add_c(nx, swig_ptr(x))

    def replacement_add_with_ids(self, x, ids):
        nx, nbytes = x.shape
        assert nbytes * 8 == self.d
        assert ids.shape == (nx, ), 'not same nb of vectors as ids'
        self.add_with_ids_c(nx, swig_ptr(x), swig_ptr(ids))

    def replacement_train(self, x):
        assert x.flags.contiguous
        nx, nbytes = x.shape
        assert nbytes * 8 == self.d
        self.train_c(nx, swig_ptr(x))

    def replacement_reconstruct(self, key):
        code = np.empty(self.d // 8, dtype=np.uint8)
        self.reconstruct_c(key, swig_ptr(code))
        return code

    def replacement_search(self, x, k):
        # distances are returned as int32 for binary indexes
        nx, nbytes = x.shape
        assert nbytes * 8 == self.d
        distances = np.empty((nx, k), dtype=np.int32)
        labels = np.empty((nx, k), dtype=np.int64)
        self.search_c(nx, swig_ptr(x), k,
                      swig_ptr(distances), swig_ptr(labels))
        return distances, labels

    def replacement_range_search(self, x, thresh):
        # returns (lims, D, I): results of query i are D[lims[i]:lims[i + 1]]
        nx, nbytes = x.shape
        assert nbytes * 8 == self.d
        res = RangeSearchResult(nx)
        self.range_search_c(nx, swig_ptr(x), thresh, res)
        # rev_swig_ptr presumably gives views on res's buffers: copy them
        lims = rev_swig_ptr(res.lims, nx + 1).copy()
        nres = int(lims[-1])
        D = rev_swig_ptr(res.distances, nres).copy()
        I = rev_swig_ptr(res.labels, nres).copy()
        return lims, D, I

    def replacement_remove_ids(self, x):
        # x is either a ready-made IDSelector or a 1-D array of ids
        if isinstance(x, IDSelector):
            sel = x
        else:
            assert x.ndim == 1
            sel = IDSelectorBatch(x.size, swig_ptr(x))
        return self.remove_ids_c(sel)

    for method_name, replacement in (
            ('add', replacement_add),
            ('add_with_ids', replacement_add_with_ids),
            ('train', replacement_train),
            ('search', replacement_search),
            ('range_search', replacement_range_search),
            ('reconstruct', replacement_reconstruct),
            ('remove_ids', replacement_remove_ids)):
        replace_method(the_class, method_name, replacement)
313
|
|
|
|
314
|
|
|
|
315
|
|
|
def handle_VectorTransform(the_class):
    """Give a VectorTransform class numpy-friendly methods.

    Adds ``apply_py`` (x of width d_in -> float32 result of width d_out) and
    replaces ``train`` and ``reverse_transform`` (d_out -> d_in).
    """

    def apply_method(self, x):
        assert x.flags.contiguous
        nx, dim = x.shape
        assert dim == self.d_in
        out = np.empty((nx, self.d_out), dtype=np.float32)
        self.apply_noalloc(nx, swig_ptr(x), swig_ptr(out))
        return out

    def replacement_reverse_transform(self, x):
        nx, dim = x.shape
        assert dim == self.d_out
        out = np.empty((nx, self.d_in), dtype=np.float32)
        self.reverse_transform_c(nx, swig_ptr(x), swig_ptr(out))
        return out

    def replacement_vt_train(self, x):
        assert x.flags.contiguous
        nx, dim = x.shape
        assert dim == self.d_in
        self.train_c(nx, swig_ptr(x))

    replace_method(the_class, 'train', replacement_vt_train)
    # exposed under a different name rather than replacing ``apply``, which
    # collides with the Python (2) builtin of the same name
    the_class.apply_py = apply_method
    replace_method(the_class, 'reverse_transform',
                   replacement_reverse_transform)
343
|
|
|
|
344
|
|
|
|
345
|
|
|
def handle_AutoTuneCriterion(the_class):
    """Patch ``set_groundtruth`` / ``evaluate`` of an AutoTuneCriterion class
    so they accept numpy arrays."""

    def replacement_set_groundtruth(self, D, I):
        # I: ground-truth labels, shape (nq, gt_nnn); D: matching distances,
        # optional (None). D must be compared with None explicitly: the truth
        # value of a multi-element numpy array raises ValueError.
        if D is not None:
            assert I.shape == D.shape
        self.nq, self.gt_nnn = I.shape
        self.set_groundtruth_c(
            self.gt_nnn, swig_ptr(D) if D is not None else None, swig_ptr(I))

    def replacement_evaluate(self, D, I):
        # D, I: search results of shape (nq, nnn)
        assert I.shape == D.shape
        assert I.shape == (self.nq, self.nnn)
        return self.evaluate_c(swig_ptr(D), swig_ptr(I))

    replace_method(the_class, 'set_groundtruth', replacement_set_groundtruth)
    replace_method(the_class, 'evaluate', replacement_evaluate)
360
|
|
|
|
361
|
|
|
|
362
|
|
|
def handle_ParameterSpace(the_class):
    """Patch ``explore`` so it takes the queries as a numpy matrix and
    returns the OperatingPoints it collected."""

    def replacement_explore(self, index, xq, crit):
        # one query row per criterion query, each of the index dimension
        assert xq.shape == (crit.nq, index.d)
        points = OperatingPoints()
        self.explore_c(index, crit.nq, swig_ptr(xq), crit, points)
        return points

    replace_method(the_class, 'explore', replacement_explore)
370
|
|
|
|
371
|
|
|
|
372
|
|
|
def handle_MatrixStats(the_class):
    """Let ``the_class`` be constructed directly from a 2-D numpy array."""
    wrapped_init = the_class.__init__

    def replacement_init(self, m):
        assert len(m.shape) == 2
        nrow, ncol = m.shape
        wrapped_init(self, nrow, ncol, swig_ptr(m))

    the_class.__init__ = replacement_init


handle_MatrixStats(MatrixStats)
|
|
|
|
382
|
|
|
|
383
|
|
|
|
384
|
|
|
this_module = sys.modules[__name__]

# Patch every class exported by the module, in dir() (sorted) order,
# according to the faiss base classes it derives from; a class matching
# several bases gets every corresponding handler, in the order listed.
for symbol in dir(this_module):
    obj = getattr(this_module, symbol)
    if not inspect.isclass(obj):
        continue
    the_class = obj
    for _base, _handler in ((Index, handle_Index),
                            (IndexBinary, handle_IndexBinary),
                            (VectorTransform, handle_VectorTransform),
                            (AutoTuneCriterion, handle_AutoTuneCriterion),
                            (ParameterSpace, handle_ParameterSpace)):
        if issubclass(the_class, _base):
            _handler(the_class)
406
|
|
|
|
407
|
|
|
|
408
|
|
|
########################################### |
409
|
|
|
# Add Python references to objects |
410
|
|
|
# we do this at the Python class wrapper level. |
411
|
|
|
########################################### |
412
|
|
|
|
413
|
|
|
def add_ref_in_constructor(the_class, parameter_no):
    """Make instances of ``the_class`` keep references to some constructor
    arguments, so that those objects are not deallocated before ``self``.

    The kept objects are stored in ``self.referenced_objects``.

    Args:
        the_class: class whose ``__init__`` gets wrapped.
        parameter_no: either the position of the argument to keep, or a dict
            mapping a number of positional arguments to the list of
            positions to keep for that constructor overload.
    """
    original_init = the_class.__init__

    def replacement_init(self, *args):
        original_init(self, *args)
        # overloads called with fewer arguments (e.g. a default constructor)
        # have nothing to keep alive
        self.referenced_objects = (
            [args[parameter_no]] if parameter_no < len(args) else [])

    def replacement_init_multiple(self, *args):
        original_init(self, *args)
        # overloads not listed keep nothing, instead of raising KeyError
        # after the object has already been constructed
        pset = parameter_no.get(len(args), ())
        self.referenced_objects = [args[no] for no in pset]

    if isinstance(parameter_no, dict):
        # a list of parameters to keep, depending on the number of arguments
        the_class.__init__ = replacement_init_multiple
    else:
        the_class.__init__ = replacement_init
432
|
|
|
|
433
|
|
|
def add_ref_in_method(the_class, method_name, parameter_no):
    """Wrap ``method_name`` so that each call appends its positional
    argument ``parameter_no`` to ``self.referenced_objects`` (created on
    first use), keeping that object alive as long as ``self``."""
    wrapped = getattr(the_class, method_name)

    def replacement_method(self, *args):
        kept = args[parameter_no]
        if hasattr(self, 'referenced_objects'):
            self.referenced_objects.append(kept)
        else:
            self.referenced_objects = [kept]
        return wrapped(self, *args)

    setattr(the_class, method_name, replacement_method)
443
|
|
|
|
444
|
|
|
def add_ref_in_function(function_name, parameter_no):
    """Wrap the module-level function ``function_name`` so that the object
    it returns keeps a reference to its positional argument ``parameter_no``
    (stored as ``result.referenced_objects``).

    Assumes the function returns an object that accepts new attributes.
    """
    wrapped = getattr(this_module, function_name)

    def replacement_function(*args):
        result = wrapped(*args)
        result.referenced_objects = [args[parameter_no]]
        return result

    setattr(this_module, function_name, replacement_function)
453
|
|
|
|
454
|
|
|
# Objects that take another object as argument 0 (a quantizer, a wrapped
# index, a sub-index, an I/O stream...) keep a Python reference to it, so it
# is not deallocated while they are alive. Calls are made in a fixed order.
for _cls in (IndexIVFFlat, IndexIVFFlatDedup):
    add_ref_in_constructor(_cls, 0)
add_ref_in_constructor(IndexPreTransform, {2: [0, 1], 1: [0]})
add_ref_in_method(IndexPreTransform, 'prepend_transform', 0)
for _cls in (IndexIVFPQ, IndexIVFPQR, Index2Layer, Level1Quantizer,
             IndexIVFScalarQuantizer, IndexIDMap, IndexIDMap2, IndexHNSW):
    add_ref_in_constructor(_cls, 0)
for _cls in (IndexShards, IndexBinaryShards):
    add_ref_in_method(_cls, 'add_shard', 0)
for _cls in (IndexRefineFlat, IndexBinaryIVF, IndexBinaryFromFloat,
             IndexBinaryIDMap, IndexBinaryIDMap2):
    add_ref_in_constructor(_cls, 0)
for _cls in (IndexReplicas, IndexBinaryReplicas):
    add_ref_in_method(_cls, 'addIndex', 0)
for _cls in (BufferedIOWriter, BufferedIOReader):
    add_ref_in_constructor(_cls, 0)

# seems really marginal...
# remove_ref_from_method(IndexReplicas, 'removeIndex', 0)

if hasattr(this_module, 'GpuIndexFlat'):
    # handle all the GPUResources refs
    add_ref_in_function('index_cpu_to_gpu', 0)
    for _cls in (GpuIndexFlat, GpuIndexFlatIP, GpuIndexFlatL2,
                 GpuIndexIVFFlat, GpuIndexIVFScalarQuantizer, GpuIndexIVFPQ,
                 GpuIndexBinaryFlat):
        add_ref_in_constructor(_cls, 0)

del _cls
|
|
|
|
493
|
|
|
|
494
|
|
|
|
495
|
|
|
|
496
|
|
|
########################################### |
497
|
|
|
# GPU functions |
498
|
|
|
########################################### |
499
|
|
|
|
500
|
|
|
|
501
|
|
|
def index_cpu_to_gpu_multiple_py(resources, index, co=None, gpus=None):
    """Clone ``index`` onto several GPUs.

    Builds the C++ vectors of GPU resources and device ids expected by
    index_cpu_to_gpu_multiple: ``resources[i]`` is paired with device
    ``gpus[i]``, and when ``gpus`` is None devices 0 .. len(resources) - 1
    are used. The returned index keeps a reference to ``resources``.
    """
    devices = range(len(resources)) if gpus is None else gpus
    vres = GpuResourcesVector()
    vdev = IntVector()
    for device, res in zip(devices, resources):
        vdev.push_back(device)
        vres.push_back(res)
    gpu_index = index_cpu_to_gpu_multiple(vres, vdev, index, co)
    gpu_index.referenced_objects = resources
    return gpu_index
515
|
|
|
|
516
|
|
|
|
517
|
|
|
def index_cpu_to_all_gpus(index, co=None, ngpu=-1):
    """Clone ``index`` onto the first ``ngpu`` GPUs (all of them when
    ``ngpu`` is -1); see index_cpu_to_gpus_list."""
    return index_cpu_to_gpus_list(index, co=co, gpus=None, ngpu=ngpu)
520
|
|
|
|
521
|
|
|
|
522
|
|
|
def index_cpu_to_gpus_list(index, co=None, gpus=None, ngpu=-1):
    """Clone ``index`` onto a list of GPUs.

    ``gpus`` is an explicit list of device ids (it takes precedence). When
    it is None, the first ``ngpu`` devices are used, or get_num_gpus()
    devices when ``ngpu`` is -1. One StandardGpuResources object is created
    per device.
    """
    if gpus is None:
        gpus = range(get_num_gpus() if ngpu == -1 else ngpu)
    res = [StandardGpuResources() for _ in gpus]
    return index_cpu_to_gpu_multiple_py(res, index, co, gpus)
532
|
|
|
|
533
|
|
|
|
534
|
|
|
########################################### |
535
|
|
|
# numpy array / std::vector conversions |
536
|
|
|
########################################### |
537
|
|
|
|
538
|
|
|
# Element-type prefix of the SWIG std::vector wrapper class names
# (e.g. 'Float' for FloatVector) -> matching numpy dtype name.
vector_name_map = dict(
    Float='float32',
    Byte='uint8',
    Char='int8',
    Uint64='uint64',
    Long='int64',
    Int='int32',
    Double='float64',
)
548
|
|
|
|
549
|
|
|
def vector_to_array(v):
    """Copy the content of a SWIG ``<Type>Vector`` into a new numpy array.

    The dtype is looked up in vector_name_map from the wrapper's class name.
    """
    classname = v.__class__.__name__
    assert classname.endswith('Vector')
    dtype = np.dtype(vector_name_map[classname[:-len('Vector')]])
    size = v.size()
    out = np.empty(size, dtype=dtype)
    if size > 0:
        memcpy(swig_ptr(out), v.data(), out.nbytes)
    return out
558
|
|
|
|
559
|
|
|
|
560
|
|
|
def vector_float_to_array(v):
    """Convert a SWIG FloatVector to a numpy array.

    Thin wrapper around vector_to_array, which derives the dtype from the
    class of ``v`` (so other vector types are accepted as well).
    """
    converted = vector_to_array(v)
    return converted
562
|
|
|
|
563
|
|
|
|
564
|
|
|
def copy_array_to_vector(a, v):
    """Overwrite the SWIG ``<Type>Vector`` ``v`` with the content of the
    1-D numpy array ``a``; the dtypes must match exactly."""
    n, = a.shape
    classname = v.__class__.__name__
    assert classname.endswith('Vector')
    expected = np.dtype(vector_name_map[classname[:-len('Vector')]])
    assert expected == a.dtype, (
        'cannot copy a %s array to a %s (should be %s)' % (
            a.dtype, classname, expected))
    v.resize(n)
    if n > 0:
        memcpy(v.data(), swig_ptr(a), a.nbytes)
|
|
|
|
576
|
|
|
|
577
|
|
|
|
578
|
|
|
########################################### |
579
|
|
|
# Wrapper for a few functions |
580
|
|
|
########################################### |
581
|
|
|
|
582
|
|
View Code Duplication |
def kmin(array, k):
    """Return the k smallest values of each row of a float32 matrix and
    their indices, as (D, I) arrays of shape (nrows, k)."""
    nrow, ncol = array.shape
    I = np.zeros((nrow, k), dtype='int64')
    D = np.zeros((nrow, k), dtype='float32')
    heap = float_maxheap_array_t()
    heap.ids = swig_ptr(I)
    heap.val = swig_ptr(D)
    heap.nh = nrow
    heap.k = k
    heap.heapify()
    heap.addn(ncol, swig_ptr(array))
    heap.reorder()
    return D, I
597
|
|
|
|
598
|
|
|
|
599
|
|
View Code Duplication |
def kmax(array, k):
    """Return the k largest values of each row of a float32 matrix and
    their indices, as (D, I) arrays of shape (nrows, k)."""
    nrow, ncol = array.shape
    I = np.zeros((nrow, k), dtype='int64')
    D = np.zeros((nrow, k), dtype='float32')
    heap = float_minheap_array_t()
    heap.ids = swig_ptr(I)
    heap.val = swig_ptr(D)
    heap.nh = nrow
    heap.k = k
    heap.heapify()
    heap.addn(ncol, swig_ptr(array))
    heap.reorder()
    return D, I
614
|
|
|
|
615
|
|
|
|
616
|
|
|
def pairwise_distances(xq, xb, mt=METRIC_L2, metric_arg=0):
    """Return the (nq, nb) float32 matrix of distances between every row of
    ``xq`` and every row of ``xb``.

    METRIC_L2 goes through pairwise_L2sqr; any other metric goes through
    pairwise_extra_distances, which also receives ``metric_arg``.
    """
    nq, d = xq.shape
    nb, d2 = xb.shape
    assert d == d2
    dis = np.empty((nq, nb), dtype='float32')
    if mt != METRIC_L2:
        pairwise_extra_distances(d, nq, swig_ptr(xq), nb, swig_ptr(xb),
                                 mt, metric_arg, swig_ptr(dis))
    else:
        pairwise_L2sqr(d, nq, swig_ptr(xq), nb, swig_ptr(xb),
                       swig_ptr(dis))
    return dis
635
|
|
|
|
636
|
|
|
|
637
|
|
|
|
638
|
|
|
|
639
|
|
|
def rand(n, seed=12345):
    """Return a float32 array of size/shape ``n`` filled by faiss'
    float_rand with the given seed."""
    out = np.empty(n, dtype='float32')
    float_rand(swig_ptr(out), out.size, seed)
    return out
643
|
|
|
|
644
|
|
|
|
645
|
|
|
def randint(n, seed=12345, vmax=None):
    """Return an int64 array of size/shape ``n`` of pseudo-random values.

    Filled by int64_rand, or by int64_rand_max when ``vmax`` is given
    (presumably bounding the values by vmax).
    """
    out = np.empty(n, dtype='int64')
    if vmax is not None:
        int64_rand_max(swig_ptr(out), out.size, vmax, seed)
    else:
        int64_rand(swig_ptr(out), out.size, seed)
    return out


lrand = randint
654
|
|
|
|
655
|
|
|
def randn(n, seed=12345):
    """Return a float32 array of size/shape ``n`` filled by faiss'
    float_randn with the given seed."""
    out = np.empty(n, dtype='float32')
    float_randn(swig_ptr(out), out.size, seed)
    return out
659
|
|
|
|
660
|
|
|
|
661
|
|
|
def eval_intersection(I1, I2):
    """Return the sum over rows of the intersection sizes of two result
    tables that have the same number of rows."""
    n = I1.shape[0]
    assert I2.shape[0] == n
    k1, k2 = I1.shape[1], I2.shape[1]
    return sum(
        ranklist_intersection_size(k1, swig_ptr(I1[row]),
                                   k2, swig_ptr(I2[row]))
        for row in range(n))
671
|
|
|
|
672
|
|
|
|
673
|
|
|
def normalize_L2(x):
    """L2-normalize the rows of ``x`` through fvec_renorm_L2, which writes
    into x's buffer (nothing is returned)."""
    dim, nrow = x.shape[1], x.shape[0]
    fvec_renorm_L2(dim, nrow, swig_ptr(x))
|
|
|
|
675
|
|
|
|
676
|
|
|
# MapLong2Long interface |
677
|
|
|
|
678
|
|
|
def replacement_map_add(self, keys, vals):
    """MapLong2Long.add taking two 1-D arrays of the same length
    (presumably int64, matching the C++ key/value type)."""
    n, = keys.shape
    # every key needs a value: n entries are read from both buffers
    assert vals.shape == (n,), 'keys and vals must have the same shape'
    self.add_c(n, swig_ptr(keys), swig_ptr(vals))
|
|
|
|
682
|
|
|
|
683
|
|
|
def replacement_map_search_multiple(self, keys):
    """Look up every entry of the 1-D array ``keys``; returns an int64
    array holding one value per key."""
    nkeys, = keys.shape
    found = np.empty(nkeys, dtype='int64')
    self.search_multiple_c(nkeys, swig_ptr(keys), swig_ptr(found))
    return found
688
|
|
|
|
689
|
|
|
# MapLong2Long: numpy-friendly add / search_multiple (originals kept as *_c)
for _method, _impl in (('add', replacement_map_add),
                       ('search_multiple', replacement_map_search_multiple)):
    replace_method(MapLong2Long, _method, _impl)
del _method, _impl
691
|
|
|
|
692
|
|
|
|
693
|
|
|
########################################### |
694
|
|
|
# Kmeans object |
695
|
|
|
########################################### |
696
|
|
|
|
697
|
|
|
|
698
|
|
|
class Kmeans:
    """shallow wrapper around the Clustering object. The important method
    is train()."""

    def __init__(self, d, k, **kwargs):
        """d: input dimension, k: nb of centroids. Additional
        parameters are passed on the ClusteringParameters object,
        including niter=25, verbose=False, spherical = False

        gpu: False (default) to run on CPU, True to use all GPUs, or an int
        to use that many GPUs.
        """
        self.d = d
        self.k = k
        self.gpu = False
        self.cp = ClusteringParameters()
        for name, value in kwargs.items():
            if name == 'gpu':
                self.gpu = value
            else:
                # if this raises an exception, it means that it is a
                # non-existent field
                getattr(self.cp, name)
                setattr(self.cp, name, value)
        self.centroids = None

    @staticmethod
    def _ngpu(gpu):
        """Map the ``gpu`` option to the ``ngpu`` argument of
        index_cpu_to_all_gpus: True -> -1 (all GPUs), an int -> that count.

        Tested by identity because ``1 == True`` holds in Python: gpu=1 must
        mean one GPU, not all of them.
        """
        return -1 if gpu is True else gpu

    def train(self, x, weights=None):
        """Run k-means on the (n, d) matrix ``x`` (optional per-vector
        ``weights``). Sets self.centroids (k, d) and self.obj (objective per
        iteration) and returns the last objective, or 0.0 if none."""
        n, d = x.shape
        assert d == self.d
        clus = Clustering(d, self.k, self.cp)
        self.index = IndexFlatIP(d) if self.cp.spherical else IndexFlatL2(d)
        if self.gpu:
            self.index = index_cpu_to_all_gpus(
                self.index, ngpu=self._ngpu(self.gpu))
        clus.train(x, self.index, weights)
        centroids = vector_float_to_array(clus.centroids)
        self.centroids = centroids.reshape(self.k, d)
        stats = clus.iteration_stats
        self.obj = np.array([stats.at(i).obj for i in range(stats.size())])
        return self.obj[-1] if self.obj.size > 0 else 0.0

    def assign(self, x):
        """Return (distances, ids) of the nearest centroid for each row of
        ``x``, as 1-D arrays. train() must have been called first."""
        assert self.centroids is not None, "should train before assigning"
        self.index.reset()
        self.index.add(self.centroids)
        D, I = self.index.search(x, 1)
        return D.ravel(), I.ravel()
749
|
|
|
|
750
|
|
|
# Old names kept as aliases so that existing user code keeps working
# (IndexProxy was renamed to IndexReplicas).
IndexProxy, ConcatenatedInvertedLists = IndexReplicas, HStackInvertedLists
|
|
|
|
754
|
|
|
|
755
|
|
|
########################################### |
756
|
|
|
# serialization of indexes to byte arrays |
757
|
|
|
########################################### |
758
|
|
|
|
759
|
|
|
def serialize_index(index):
    """Serialize ``index`` with write_index into memory and return the bytes
    as a numpy uint8 array."""
    buf = VectorIOWriter()
    write_index(index, buf)
    return vector_to_array(buf.data)
764
|
|
|
|
765
|
|
|
def deserialize_index(data):
    """Rebuild an index from the uint8 array produced by serialize_index."""
    buf = VectorIOReader()
    copy_array_to_vector(data, buf.data)
    return read_index(buf)
|
|
|
|
769
|
|
|
|
770
|
|
|
def serialize_index_binary(index):
    """Serialize a binary index with write_index_binary into memory and
    return the bytes as a numpy uint8 array."""
    buf = VectorIOWriter()
    write_index_binary(index, buf)
    return vector_to_array(buf.data)
775
|
|
|
|
776
|
|
|
def deserialize_index_binary(data):
    """Rebuild a binary index from the uint8 array produced by
    serialize_index_binary."""
    buf = VectorIOReader()
    copy_array_to_vector(data, buf.data)
    return read_index_binary(buf)
|
|
|
|
780
|
|
|
|
781
|
|
|
|
782
|
|
|
########################################### |
783
|
|
|
# ResultHeap |
784
|
|
|
########################################### |
785
|
|
|
|
786
|
|
View Code Duplication |
class ResultHeap:
    """Accumulate query results from a sliced dataset. The final result will
    be in self.D, self.I."""

    def __init__(self, nq, k):
        """nq: number of query vectors, k: number of results per query"""
        self.I = np.zeros((nq, k), dtype='int64')
        self.D = np.zeros((nq, k), dtype='float32')
        self.nq, self.k = nq, k
        heaps = float_maxheap_array_t()
        heaps.k = k
        heaps.nh = nq
        # the heaps operate directly on the buffers of self.D / self.I
        heaps.val = swig_ptr(self.D)
        heaps.ids = swig_ptr(self.I)
        heaps.heapify()
        self.heaps = heaps

    def add_result(self, D, I):
        """Merge the results of one slice into the heaps.

        D, I: (nq, kd) distances / ids for the same nq queries; kd may differ
        from k. They do not need to be in a particular order (heap or
        sorted). They are converted to contiguous float32 / int64 arrays if
        needed, because only their buffer pointers are passed on.
        """
        nq, kd = D.shape
        D = np.ascontiguousarray(D, dtype='float32')
        I = np.ascontiguousarray(I, dtype='int64')
        assert I.shape == (nq, kd)
        assert nq == self.nq
        # ``faiss`` is not a name bound inside the package itself: use the
        # module-level swig_ptr directly
        self.heaps.addn_with_ids(kd, swig_ptr(D), swig_ptr(I), kd)

    def finalize(self):
        """Reorder the heaps into the final result lists in self.D / self.I;
        call once after the last add_result."""
        self.heaps.reorder()
813
|
|
|
|