Total Complexity | 73 |
Total Lines | 813 |
Duplicated Lines | 9.72 % |
Changes | 0 |
Duplicate code is one of the most pungent code smells. A rule that is often used is to re-structure code once it is duplicated in three or more places.
Common duplication problems, and corresponding solutions are:
Complex classes like faiss often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
1 | # Copyright (c) Facebook, Inc. and its affiliates. |
||
2 | # |
||
3 | # This source code is licensed under the MIT license found in the |
||
4 | # LICENSE file in the root directory of this source tree. |
||
5 | |||
6 | #@nolint |
||
7 | |||
8 | # not linting this file because it imports * from swigfaiss, which |
||
9 | # causes a ton of useless warnings. |
||
10 | |||
11 | import numpy as np |
||
12 | import sys |
||
13 | import inspect |
||
14 | import pdb |
||
15 | import platform |
||
16 | import subprocess |
||
17 | import logging |
||
18 | |||
19 | |||
20 | logger = logging.getLogger(__name__) |
||
21 | |||
22 | |||
def instruction_set():
    """Report which build of the native library the host CPU can run.

    Returns:
        "AVX2" when AVX2 support is detected, "default" when it is not.
        On platforms other than macOS and Linux no detection is attempted
        and the function falls through, returning None.

    Raises:
        ImportError: on Linux when ``numpy.distutils.cpuinfo`` cannot be
            imported.
    """
    system = platform.system()
    if system == "Darwin":
        try:
            raw = subprocess.check_output(
                ["/usr/sbin/sysctl", "hw.optional.avx2_0"])
        except (subprocess.CalledProcessError, OSError):
            # sysctl is missing or does not know the key: no AVX2 build.
            return "default"
        # check_output returns bytes on Python 3 and the text ends with a
        # newline, so the value has to be decoded and stripped before its
        # last character can be compared with "1".
        if isinstance(raw, bytes):
            raw = raw.decode("utf-8", "replace")
        if raw.strip().endswith("1"):
            return "AVX2"
        return "default"
    elif system == "Linux":
        import numpy.distutils.cpuinfo
        if "avx2" in numpy.distutils.cpuinfo.cpu.info[0].get('flags', ""):
            return "AVX2"
        return "default"
||
35 | |||
36 | |||
37 | try: |
||
38 | instr_set = instruction_set() |
||
39 | if instr_set == "AVX2": |
||
40 | logger.info("Loading faiss with AVX2 support.") |
||
41 | from .swigfaiss_avx2 import * |
||
42 | else: |
||
43 | logger.info("Loading faiss.") |
||
44 | from .swigfaiss import * |
||
45 | |||
46 | except ImportError: |
||
47 | # we import * so that the symbol X can be accessed as faiss.X |
||
48 | logger.info("Loading faiss.") |
||
49 | from .swigfaiss import * |
||
50 | |||
51 | |||
52 | __version__ = "%d.%d.%d" % (FAISS_VERSION_MAJOR, |
||
|
|||
53 | FAISS_VERSION_MINOR, |
||
54 | FAISS_VERSION_PATCH) |
||
55 | |||
56 | ################################################################## |
||
57 | # The functions below add or replace some methods for classes |
||
58 | # this is to be able to pass in numpy arrays directly |
||
59 | # The C++ version of the classnames will be suffixed with _c |
||
60 | ################################################################## |
||
61 | |||
62 | |||
def replace_method(the_class, name, replacement, ignore_missing=False):
    """Swap method ``name`` of ``the_class`` for ``replacement``.

    The method being replaced stays reachable as ``name + '_c'``.  Nothing
    is changed when the attribute found is already a function called
    ``'replacement_' + name`` (the swap was done on a parent class).

    Raises:
        AttributeError: ``the_class`` has no attribute ``name`` and
            ``ignore_missing`` is false.
    """
    try:
        current = getattr(the_class, name)
    except AttributeError:
        if not ignore_missing:
            raise
        return
    already_patched = current.__name__ == 'replacement_' + name
    if not already_patched:
        setattr(the_class, name + '_c', current)
        setattr(the_class, name, replacement)
||
75 | |||
76 | |||
def handle_Clustering():
    """Make Clustering.train and Clustering.train_encoded take arrays.

    Sizes are read from the array shapes; the wrapped originals are called
    through their ``_c`` names.
    """

    def replacement_train(self, x, index, weights=None):
        nb_vectors, dim = x.shape
        assert dim == self.d
        if weights is None:
            self.train_c(nb_vectors, swig_ptr(x), index)
            return
        assert weights.shape == (nb_vectors, )
        self.train_c(nb_vectors, swig_ptr(x), index, swig_ptr(weights))

    def replacement_train_encoded(self, x, codec, index, weights=None):
        nb_vectors, code_width = x.shape
        assert code_width == codec.sa_code_size()
        assert codec.d == index.d
        if weights is None:
            self.train_encoded_c(nb_vectors, swig_ptr(x), codec, index)
            return
        assert weights.shape == (nb_vectors, )
        self.train_encoded_c(
            nb_vectors, swig_ptr(x), codec, index, swig_ptr(weights))

    for method_name, patched in (
            ('train', replacement_train),
            ('train_encoded', replacement_train_encoded)):
        replace_method(Clustering, method_name, patched)


handle_Clustering()
||
100 | |||
101 | |||
def handle_Quantizer(the_class):
    """Give ``the_class`` array-based train, compute_codes and decode."""

    def replacement_train(self, x):
        nb, dim = x.shape
        assert dim == self.d
        self.train_c(nb, swig_ptr(x))

    def replacement_compute_codes(self, x):
        nb, dim = x.shape
        assert dim == self.d
        # one row of self.code_size bytes per input row
        encoded = np.empty((nb, self.code_size), dtype='uint8')
        self.compute_codes_c(swig_ptr(x), swig_ptr(encoded), nb)
        return encoded

    def replacement_decode(self, codes):
        nb, width = codes.shape
        assert width == self.code_size
        decoded = np.empty((nb, self.d), dtype='float32')
        self.decode_c(swig_ptr(codes), swig_ptr(decoded), nb)
        return decoded

    for method_name, patched in (
            ('train', replacement_train),
            ('compute_codes', replacement_compute_codes),
            ('decode', replacement_decode)):
        replace_method(the_class, method_name, patched)


handle_Quantizer(ProductQuantizer)
handle_Quantizer(ScalarQuantizer)
||
130 | |||
131 | |||
def handle_Index(the_class):
    """Install array-based front ends on an Index class.

    Each replacement reads sizes from the array shapes, allocates the
    output arrays it returns, and forwards pointers to the wrapped method,
    which replace_method keeps under the same name suffixed with ``_c``.
    The nested function names must stay ``replacement_<method>`` because
    replace_method uses them to recognise an already patched parent.
    """

    def replacement_add(self, x):
        assert x.flags.contiguous
        nb, dim = x.shape
        assert dim == self.d
        self.add_c(nb, swig_ptr(x))

    def replacement_add_with_ids(self, x, ids):
        nb, dim = x.shape
        assert dim == self.d
        assert ids.shape == (nb, ), 'not same nb of vectors as ids'
        self.add_with_ids_c(nb, swig_ptr(x), swig_ptr(ids))

    def replacement_assign(self, x, k):
        nb, dim = x.shape
        assert dim == self.d
        assigned = np.empty((nb, k), dtype=np.int64)
        self.assign_c(nb, swig_ptr(x), swig_ptr(assigned), k)
        return assigned

    def replacement_train(self, x):
        assert x.flags.contiguous
        nb, dim = x.shape
        assert dim == self.d
        self.train_c(nb, swig_ptr(x))

    def replacement_search(self, x, k):
        nq, dim = x.shape
        assert dim == self.d
        out_shape = (nq, k)
        dists = np.empty(out_shape, dtype=np.float32)
        ids = np.empty(out_shape, dtype=np.int64)
        self.search_c(nq, swig_ptr(x), k, swig_ptr(dists), swig_ptr(ids))
        return dists, ids

    def replacement_search_and_reconstruct(self, x, k):
        nq, dim = x.shape
        assert dim == self.d
        dists = np.empty((nq, k), dtype=np.float32)
        ids = np.empty((nq, k), dtype=np.int64)
        vectors = np.empty((nq, k, dim), dtype=np.float32)
        self.search_and_reconstruct_c(
            nq, swig_ptr(x), k,
            swig_ptr(dists), swig_ptr(ids), swig_ptr(vectors))
        return dists, ids, vectors

    def replacement_remove_ids(self, x):
        # an IDSelector is passed through untouched
        if isinstance(x, IDSelector):
            return self.remove_ids_c(x)
        assert x.ndim == 1
        ivf = try_extract_index_ivf(self)
        uses_hashtable = (
            ivf and ivf.direct_map.type == DirectMap.Hashtable)
        selector_class = IDSelectorArray if uses_hashtable else IDSelectorBatch
        selector = selector_class(x.size, swig_ptr(x))
        return self.remove_ids_c(selector)

    def replacement_reconstruct(self, key):
        vector = np.empty(self.d, dtype=np.float32)
        self.reconstruct_c(key, swig_ptr(vector))
        return vector

    def replacement_reconstruct_n(self, n0, ni):
        vectors = np.empty((ni, self.d), dtype=np.float32)
        self.reconstruct_n_c(n0, ni, swig_ptr(vectors))
        return vectors

    def replacement_update_vectors(self, keys, x):
        nb = keys.size
        assert keys.shape == (nb, )
        assert x.shape == (nb, self.d)
        self.update_vectors_c(nb, swig_ptr(keys), swig_ptr(x))

    def replacement_range_search(self, x, thresh):
        nq, dim = x.shape
        assert dim == self.d
        result = RangeSearchResult(nq)
        self.range_search_c(nq, swig_ptr(x), thresh, result)
        # copy the three result buffers into arrays of our own
        lims = rev_swig_ptr(result.lims, nq + 1).copy()
        total = int(lims[-1])
        dists = rev_swig_ptr(result.distances, total).copy()
        ids = rev_swig_ptr(result.labels, total).copy()
        return lims, dists, ids

    def replacement_sa_encode(self, x):
        nb, dim = x.shape
        assert dim == self.d
        encoded = np.empty((nb, self.sa_code_size()), dtype='uint8')
        self.sa_encode_c(nb, swig_ptr(x), swig_ptr(encoded))
        return encoded

    def replacement_sa_decode(self, codes):
        nb, width = codes.shape
        assert width == self.sa_code_size()
        decoded = np.empty((nb, self.d), dtype='float32')
        self.sa_decode_c(nb, swig_ptr(codes), swig_ptr(decoded))
        return decoded

    # (method name, replacement, may be absent) -- applied in this order
    patches = (
        ('add', replacement_add, False),
        ('add_with_ids', replacement_add_with_ids, False),
        ('assign', replacement_assign, False),
        ('train', replacement_train, False),
        ('search', replacement_search, False),
        ('remove_ids', replacement_remove_ids, False),
        ('reconstruct', replacement_reconstruct, False),
        ('reconstruct_n', replacement_reconstruct_n, False),
        ('range_search', replacement_range_search, False),
        ('update_vectors', replacement_update_vectors, True),
        ('search_and_reconstruct',
         replacement_search_and_reconstruct, True),
        ('sa_encode', replacement_sa_encode, False),
        ('sa_decode', replacement_sa_decode, False),
    )
    for method_name, patched, optional in patches:
        replace_method(the_class, method_name, patched,
                       ignore_missing=optional)
||
250 | |||
def handle_IndexBinary(the_class):
    """Install array-based front ends on an IndexBinary class.

    Input rows are checked so that their width times 8 equals ``self.d``.
    Nested function names stay ``replacement_<method>`` because
    replace_method inspects them.
    """

    def replacement_add(self, x):
        assert x.flags.contiguous
        nb, width = x.shape
        assert width * 8 == self.d
        self.add_c(nb, swig_ptr(x))

    def replacement_add_with_ids(self, x, ids):
        nb, width = x.shape
        assert width * 8 == self.d
        assert ids.shape == (nb, ), 'not same nb of vectors as ids'
        self.add_with_ids_c(nb, swig_ptr(x), swig_ptr(ids))

    def replacement_train(self, x):
        assert x.flags.contiguous
        nb, width = x.shape
        assert width * 8 == self.d
        self.train_c(nb, swig_ptr(x))

    def replacement_reconstruct(self, key):
        vector = np.empty(self.d // 8, dtype=np.uint8)
        self.reconstruct_c(key, swig_ptr(vector))
        return vector

    def replacement_search(self, x, k):
        nq, width = x.shape
        assert width * 8 == self.d
        out_shape = (nq, k)
        dists = np.empty(out_shape, dtype=np.int32)
        ids = np.empty(out_shape, dtype=np.int64)
        self.search_c(nq, swig_ptr(x), k, swig_ptr(dists), swig_ptr(ids))
        return dists, ids

    def replacement_range_search(self, x, thresh):
        nq, width = x.shape
        assert width * 8 == self.d
        result = RangeSearchResult(nq)
        self.range_search_c(nq, swig_ptr(x), thresh, result)
        # copy the three result buffers into arrays of our own
        lims = rev_swig_ptr(result.lims, nq + 1).copy()
        total = int(lims[-1])
        dists = rev_swig_ptr(result.distances, total).copy()
        ids = rev_swig_ptr(result.labels, total).copy()
        return lims, dists, ids

    def replacement_remove_ids(self, x):
        # an IDSelector is passed through untouched
        if isinstance(x, IDSelector):
            return self.remove_ids_c(x)
        assert x.ndim == 1
        selector = IDSelectorBatch(x.size, swig_ptr(x))
        return self.remove_ids_c(selector)

    # applied in this order
    for method_name, patched in (
            ('add', replacement_add),
            ('add_with_ids', replacement_add_with_ids),
            ('train', replacement_train),
            ('search', replacement_search),
            ('range_search', replacement_range_search),
            ('reconstruct', replacement_reconstruct),
            ('remove_ids', replacement_remove_ids)):
        replace_method(the_class, method_name, patched)
||
313 | |||
314 | |||
def handle_VectorTransform(the_class):
    """Add array-based train, apply_py and reverse_transform.

    ``apply_py`` is added as a new attribute; ``train`` and
    ``reverse_transform`` go through replace_method.
    """

    def apply_method(self, x):
        assert x.flags.contiguous
        nb, dim = x.shape
        assert dim == self.d_in
        transformed = np.empty((nb, self.d_out), dtype=np.float32)
        self.apply_noalloc(nb, swig_ptr(x), swig_ptr(transformed))
        return transformed

    def replacement_reverse_transform(self, x):
        nb, dim = x.shape
        assert dim == self.d_out
        restored = np.empty((nb, self.d_in), dtype=np.float32)
        self.reverse_transform_c(nb, swig_ptr(x), swig_ptr(restored))
        return restored

    def replacement_vt_train(self, x):
        assert x.flags.contiguous
        nb, dim = x.shape
        assert dim == self.d_in
        self.train_c(nb, swig_ptr(x))

    replace_method(the_class, 'train', replacement_vt_train)
    # apply is reserved in Python...
    setattr(the_class, 'apply_py', apply_method)
    replace_method(
        the_class, 'reverse_transform', replacement_reverse_transform)
||
343 | |||
344 | |||
def handle_AutoTuneCriterion(the_class):
    """Make set_groundtruth and evaluate take arrays of distances / labels."""

    def replacement_set_groundtruth(self, D, I):
        """Record the ground truth labels I and, optionally, distances D.

        D may be None (or empty), in which case no distance pointer is
        passed on.  self.nq and self.gt_nnn are set from I.shape.
        """
        # Testing the array itself (`if D:`) raises ValueError as soon as D
        # holds more than one element, so test for presence instead.
        has_distances = D is not None and np.size(D) > 0
        if has_distances:
            assert I.shape == D.shape
        self.nq, self.gt_nnn = I.shape
        self.set_groundtruth_c(
            self.gt_nnn, swig_ptr(D) if has_distances else None, swig_ptr(I))

    def replacement_evaluate(self, D, I):
        """Forward D and I, both of shape (self.nq, self.nnn), to evaluate_c."""
        assert I.shape == D.shape
        assert I.shape == (self.nq, self.nnn)
        return self.evaluate_c(swig_ptr(D), swig_ptr(I))

    replace_method(the_class, 'set_groundtruth', replacement_set_groundtruth)
    replace_method(the_class, 'evaluate', replacement_evaluate)
||
360 | |||
361 | |||
def handle_ParameterSpace(the_class):
    """Make explore take the queries as an array and return the result."""

    def replacement_explore(self, index, xq, crit):
        expected_shape = (crit.nq, index.d)
        assert xq.shape == expected_shape
        # filled in by explore_c and handed back to the caller
        operating_points = OperatingPoints()
        self.explore_c(index, crit.nq, swig_ptr(xq), crit, operating_points)
        return operating_points

    replace_method(the_class, 'explore', replacement_explore)
||
370 | |||
371 | |||
def handle_MatrixStats(the_class):
    """Let ``the_class`` be constructed from a single 2-D array."""
    wrapped_init = the_class.__init__

    def replacement_init(self, m):
        assert len(m.shape) == 2
        nb_rows, nb_cols = m.shape[0], m.shape[1]
        wrapped_init(self, nb_rows, nb_cols, swig_ptr(m))

    setattr(the_class, '__init__', replacement_init)


handle_MatrixStats(MatrixStats)
||
382 | |||
383 | |||
this_module = sys.modules[__name__]

# (base class, patch function) pairs; a class gets every patch whose base
# class it derives from, tried in this order
_class_patches = (
    (Index, handle_Index),
    (IndexBinary, handle_IndexBinary),
    (VectorTransform, handle_VectorTransform),
    (AutoTuneCriterion, handle_AutoTuneCriterion),
    (ParameterSpace, handle_ParameterSpace),
)

for symbol in dir(this_module):
    obj = getattr(this_module, symbol)
    if not inspect.isclass(obj):
        continue
    the_class = obj
    for _base_class, _handler in _class_patches:
        if issubclass(the_class, _base_class):
            _handler(the_class)
||
406 | |||
407 | |||
408 | ########################################### |
||
409 | # Add Python references to objects |
||
410 | # we do this at the Python class wrapper level. |
||
411 | ########################################### |
||
412 | |||
def add_ref_in_constructor(the_class, parameter_no):
    """Keep constructor arguments alive for as long as the instance lives.

    After the wrapped ``__init__`` has run, the selected positional
    arguments are stored in ``self.referenced_objects`` so that they are
    not deallocated before ``self``.

    Args:
        the_class: class whose ``__init__`` is wrapped in place.
        parameter_no: either the position of the argument to keep, or a
            dict mapping a number of positional arguments to the list of
            positions to keep for calls made with that many arguments.

    A call that does not supply the selected argument (for instance the
    default constructor, or an argument count absent from the dict) keeps
    no reference instead of failing after the object has been built.
    """
    original_init = the_class.__init__

    def replacement_init(self, *args):
        original_init(self, *args)
        try:
            kept = [args[parameter_no]]
        except IndexError:
            # constructed without the argument we were asked to hold on to
            kept = []
        self.referenced_objects = kept

    def replacement_init_multiple(self, *args):
        original_init(self, *args)
        positions = parameter_no.get(len(args), ())
        self.referenced_objects = [args[no] for no in positions]

    if isinstance(parameter_no, dict):
        # a list of parameters to keep, depending on the number of arguments
        the_class.__init__ = replacement_init_multiple
    else:
        the_class.__init__ = replacement_init
||
432 | |||
def add_ref_in_method(the_class, method_name, parameter_no):
    """Wrap a method so one of its arguments is remembered on ``self``.

    Before each call to the wrapped method, positional argument
    ``parameter_no`` is appended to ``self.referenced_objects`` (the list
    is created on first use).  The wrapped method's result is returned.
    """
    wrapped = getattr(the_class, method_name)

    def replacement_method(self, *args):
        kept = args[parameter_no]
        if hasattr(self, 'referenced_objects'):
            self.referenced_objects.append(kept)
        else:
            self.referenced_objects = [kept]
        return wrapped(self, *args)

    setattr(the_class, method_name, replacement_method)
||
443 | |||
def add_ref_in_function(function_name, parameter_no):
    """Wrap a module-level function so its result remembers an argument.

    The function is looked up on ``this_module`` and replaced there.  The
    wrapper stores positional argument ``parameter_no`` in the result's
    ``referenced_objects`` list; the function is assumed to return an
    object that accepts new attributes.
    """
    wrapped = getattr(this_module, function_name)

    def replacement_function(*args):
        produced = wrapped(*args)
        produced.referenced_objects = [args[parameter_no]]
        return produced

    setattr(this_module, function_name, replacement_function)
||
453 | |||
# Every entry is (patch function, arguments); entries are applied from top
# to bottom once the table is complete.
_ref_patches = [
    (add_ref_in_constructor, (IndexIVFFlat, 0)),
    (add_ref_in_constructor, (IndexIVFFlatDedup, 0)),
    (add_ref_in_constructor, (IndexPreTransform, {2: [0, 1], 1: [0]})),
    (add_ref_in_method, (IndexPreTransform, 'prepend_transform', 0)),
    (add_ref_in_constructor, (IndexIVFPQ, 0)),
    (add_ref_in_constructor, (IndexIVFPQR, 0)),
    (add_ref_in_constructor, (Index2Layer, 0)),
    (add_ref_in_constructor, (Level1Quantizer, 0)),
    (add_ref_in_constructor, (IndexIVFScalarQuantizer, 0)),
    (add_ref_in_constructor, (IndexIDMap, 0)),
    (add_ref_in_constructor, (IndexIDMap2, 0)),
    (add_ref_in_constructor, (IndexHNSW, 0)),
    (add_ref_in_method, (IndexShards, 'add_shard', 0)),
    (add_ref_in_method, (IndexBinaryShards, 'add_shard', 0)),
    (add_ref_in_constructor, (IndexRefineFlat, 0)),
    (add_ref_in_constructor, (IndexBinaryIVF, 0)),
    (add_ref_in_constructor, (IndexBinaryFromFloat, 0)),
    (add_ref_in_constructor, (IndexBinaryIDMap, 0)),
    (add_ref_in_constructor, (IndexBinaryIDMap2, 0)),
    (add_ref_in_method, (IndexReplicas, 'addIndex', 0)),
    (add_ref_in_method, (IndexBinaryReplicas, 'addIndex', 0)),
    (add_ref_in_constructor, (BufferedIOWriter, 0)),
    (add_ref_in_constructor, (BufferedIOReader, 0)),
]

# seems really marginal...
# remove_ref_from_method(IndexReplicas, 'removeIndex', 0)

if hasattr(this_module, 'GpuIndexFlat'):
    # handle all the GPUResources refs
    _ref_patches.append((add_ref_in_function, ('index_cpu_to_gpu', 0)))
    _ref_patches.extend(
        (add_ref_in_constructor, (gpu_class, 0))
        for gpu_class in (
            GpuIndexFlat,
            GpuIndexFlatIP,
            GpuIndexFlatL2,
            GpuIndexIVFFlat,
            GpuIndexIVFScalarQuantizer,
            GpuIndexIVFPQ,
            GpuIndexBinaryFlat,
        ))

for _apply_patch, _patch_args in _ref_patches:
    _apply_patch(*_patch_args)
del _ref_patches, _apply_patch, _patch_args
||
493 | |||
494 | |||
495 | |||
496 | ########################################### |
||
497 | # GPU functions |
||
498 | ########################################### |
||
499 | |||
500 | |||
def index_cpu_to_gpu_multiple_py(resources, index, co=None, gpus=None):
    """Call index_cpu_to_gpu_multiple from Python sequences.

    ``resources`` and ``gpus`` are paired up and copied into the vector
    objects the wrapped function takes.  When ``gpus`` is None the device
    numbers are 0 .. len(resources) - 1.  ``resources`` is stored on the
    returned index as ``referenced_objects``.
    """
    device_ids = range(len(resources)) if gpus is None else gpus
    vres = GpuResourcesVector()
    vdev = IntVector()
    for device_id, resource in zip(device_ids, resources):
        vdev.push_back(device_id)
        vres.push_back(resource)
    gpu_index = index_cpu_to_gpu_multiple(vres, vdev, index, co)
    gpu_index.referenced_objects = resources
    return gpu_index
||
515 | |||
516 | |||
def index_cpu_to_all_gpus(index, co=None, ngpu=-1):
    """Shorthand for index_cpu_to_gpus_list with no explicit GPU list."""
    return index_cpu_to_gpus_list(index, co=co, gpus=None, ngpu=ngpu)
||
520 | |||
521 | |||
def index_cpu_to_gpus_list(index, co=None, gpus=None, ngpu=-1):
    """Copy ``index`` to several GPUs, one StandardGpuResources per GPU.

    ``gpus`` must be a list of GPU ids or None.  When it is None, the first
    ``ngpu`` GPUs are used, or all GPUs reported by get_num_gpus() if
    ``ngpu`` is -1.
    """
    if gpus is None:
        nb_devices = get_num_gpus() if ngpu == -1 else ngpu
        gpus = range(nb_devices)
    resources = [StandardGpuResources() for _ in gpus]
    return index_cpu_to_gpu_multiple_py(resources, index, co, gpus)
||
532 | |||
533 | |||
534 | ########################################### |
||
535 | # numpy array / std::vector conversions |
||
536 | ########################################### |
||
537 | |||
# numpy dtype name for each vector class name prefix used in
# swigfaiss.swig (the class name is the prefix followed by 'Vector')
vector_name_map = dict([
    ('Float', 'float32'),
    ('Byte', 'uint8'),
    ('Char', 'int8'),
    ('Uint64', 'uint64'),
    ('Long', 'int64'),
    ('Int', 'int32'),
    ('Double', 'float64'),
])
||
548 | |||
def vector_to_array(v):
    """Return the contents of a C++ vector wrapper as a new numpy array.

    The dtype is looked up in vector_name_map from the wrapper's class
    name, which must end in 'Vector'.
    """
    classname = v.__class__.__name__
    assert classname.endswith('Vector')
    prefix = classname[:-len('Vector')]
    out = np.empty(v.size(), dtype=np.dtype(vector_name_map[prefix]))
    if v.size() == 0:
        # nothing to copy
        return out
    memcpy(swig_ptr(out), v.data(), out.nbytes)
    return out
||
558 | |||
559 | |||
def vector_float_to_array(v):
    """Same as vector_to_array; the dtype follows the vector's class."""
    converted = vector_to_array(v)
    return converted
||
562 | |||
563 | |||
def copy_array_to_vector(a, v):
    """Resize vector ``v`` to the length of 1-D array ``a`` and fill it.

    The array dtype must be the one vector_name_map gives for the class of
    ``v`` (whose name must end in 'Vector').
    """
    (length,) = a.shape
    classname = v.__class__.__name__
    assert classname.endswith('Vector')
    expected = np.dtype(vector_name_map[classname[:-len('Vector')]])
    assert expected == a.dtype, (
        'cannot copy a %s array to a %s (should be %s)' % (
            a.dtype, classname, expected))
    v.resize(length)
    if length <= 0:
        return
    memcpy(v.data(), swig_ptr(a), a.nbytes)
||
576 | |||
577 | |||
578 | ########################################### |
||
579 | # Wrapper for a few functions |
||
580 | ########################################### |
||
581 | |||
def _select_k_per_row(array, k, heap_array_class):
    """Shared body of kmin and kmax.

    Allocates (m, k) output arrays for an (m, n) input, attaches them to a
    new ``heap_array_class()`` object, pushes the n values of every row
    and reorders the result.  Returns (values, indices).
    """
    m, n = array.shape
    I = np.zeros((m, k), dtype='int64')
    D = np.zeros((m, k), dtype='float32')
    ha = heap_array_class()
    ha.ids = swig_ptr(I)
    ha.val = swig_ptr(D)
    ha.nh = m
    ha.k = k
    ha.heapify()
    ha.addn(n, swig_ptr(array))
    ha.reorder()
    return D, I


def kmin(array, k):
    """return k smallest values (and their indices) of the lines of a
    float32 array"""
    return _select_k_per_row(array, k, float_maxheap_array_t)


def kmax(array, k):
    """return k largest values (and their indices) of the lines of a
    float32 array"""
    return _select_k_per_row(array, k, float_minheap_array_t)
||
614 | |||
615 | |||
def pairwise_distances(xq, xb, mt=METRIC_L2, metric_arg=0):
    """Return the (nq, nb) float32 matrix of distances between the rows of
    xq and the rows of xb.

    METRIC_L2 is computed by pairwise_L2sqr; any other metric is handed to
    pairwise_extra_distances together with ``metric_arg``.
    """
    nq, dim = xq.shape
    nb, dim_b = xb.shape
    assert dim == dim_b
    out = np.empty((nq, nb), dtype='float32')
    if mt != METRIC_L2:
        pairwise_extra_distances(
            dim, nq, swig_ptr(xq), nb, swig_ptr(xb),
            mt, metric_arg, swig_ptr(out))
        return out
    pairwise_L2sqr(dim, nq, swig_ptr(xq), nb, swig_ptr(xb), swig_ptr(out))
    return out
||
635 | |||
636 | |||
637 | |||
638 | |||
def rand(n, seed=12345):
    """Return a float32 array of shape ``n`` filled by float_rand."""
    out = np.empty(n, dtype='float32')
    float_rand(swig_ptr(out), out.size, seed)
    return out
||
643 | |||
644 | |||
def randint(n, seed=12345, vmax=None):
    """Return an int64 array of shape ``n`` filled by int64_rand, or by
    int64_rand_max when ``vmax`` is given."""
    out = np.empty(n, dtype='int64')
    if vmax is not None:
        int64_rand_max(swig_ptr(out), out.size, vmax, seed)
        return out
    int64_rand(swig_ptr(out), out.size, seed)
    return out


# second name for randint
lrand = randint
||
654 | |||
def randn(n, seed=12345):
    """Return a float32 array of shape ``n`` filled by float_randn."""
    out = np.empty(n, dtype='float32')
    float_randn(swig_ptr(out), out.size, seed)
    return out
||
659 | |||
660 | |||
def eval_intersection(I1, I2):
    """Sum, over the rows, of the intersection size between row i of I1
    and row i of I2.  Both tables must have the same number of rows."""
    nb_rows = I1.shape[0]
    assert I2.shape[0] == nb_rows
    width1, width2 = I1.shape[1], I2.shape[1]
    return sum(
        ranklist_intersection_size(
            width1, swig_ptr(I1[row]), width2, swig_ptr(I2[row]))
        for row in range(nb_rows))
||
671 | |||
672 | |||
def normalize_L2(x):
    """Pass the 2-D array ``x`` to fvec_renorm_L2, which works on the
    array's own buffer; nothing is returned."""
    nb_rows, dim = x.shape[0], x.shape[1]
    fvec_renorm_L2(dim, nb_rows, swig_ptr(x))
||
675 | |||
676 | # MapLong2Long interface |
||
677 | |||
def replacement_map_add(self, keys, vals):
    """Add the pairs (keys[i], vals[i]) from two 1-D arrays of equal length."""
    n, = keys.shape
    # add_c is given n as the length of both buffers, so vals has to match
    # keys exactly
    assert (n,) == vals.shape
    self.add_c(n, swig_ptr(keys), swig_ptr(vals))
||
682 | |||
def replacement_map_search_multiple(self, keys):
    """Look up a 1-D array of keys; returns an int64 array of the same
    length filled by search_multiple_c."""
    (nb_keys,) = keys.shape
    found = np.empty(nb_keys, dtype='int64')
    self.search_multiple_c(nb_keys, swig_ptr(keys), swig_ptr(found))
    return found


for _map_method, _map_patch in (
        ('add', replacement_map_add),
        ('search_multiple', replacement_map_search_multiple)):
    replace_method(MapLong2Long, _map_method, _map_patch)
||
691 | |||
692 | |||
693 | ########################################### |
||
694 | # Kmeans object |
||
695 | ########################################### |
||
696 | |||
697 | |||
class Kmeans:
    """shallow wrapper around the Clustering object. The important method
    is train()."""

    def __init__(self, d, k, **kwargs):
        """d: input dimension, k: nb of centroids. Additional
        parameters are passed on the ClusteringParameters object,
        including niter=25, verbose=False, spherical = False

        The extra keyword ``gpu`` is kept on the wrapper instead: False
        (default) trains on CPU, True passes ngpu=-1 to
        index_cpu_to_all_gpus, and a number is passed on as ngpu.
        """
        self.d = d
        self.k = k
        self.gpu = False
        self.cp = ClusteringParameters()
        # distinct loop names so the parameter k is not shadowed
        for name, value in kwargs.items():
            if name == 'gpu':
                self.gpu = value
            else:
                # if this raises an exception, it means that it is a non-existent field
                getattr(self.cp, name)
                setattr(self.cp, name, value)
        self.centroids = None

    def train(self, x, weights=None):
        """Cluster the (n, d) array x, optionally weighted.

        Sets self.index, self.centroids (reshaped to (k, d)) and self.obj
        (the ``obj`` field of every iteration's stats).  Returns the last
        entry of self.obj, or 0.0 when there is none.
        """
        n, d = x.shape
        assert d == self.d
        clus = Clustering(d, self.k, self.cp)
        if self.cp.spherical:
            self.index = IndexFlatIP(d)
        else:
            self.index = IndexFlatL2(d)
        if self.gpu:
            # Only a real boolean means "all GPUs".  Comparing with == True
            # would also match the integer 1 and ignore a request for a
            # single GPU.
            if isinstance(self.gpu, (bool, np.bool_)):
                ngpu = -1
            else:
                ngpu = self.gpu
            self.index = index_cpu_to_all_gpus(self.index, ngpu=ngpu)
        clus.train(x, self.index, weights)
        centroids = vector_float_to_array(clus.centroids)
        self.centroids = centroids.reshape(self.k, d)
        stats = clus.iteration_stats
        self.obj = np.array([
            stats.at(i).obj for i in range(stats.size())
        ])
        return self.obj[-1] if self.obj.size > 0 else 0.0

    def assign(self, x):
        """Search the centroids for the nearest one to each row of x.

        Returns (distances, labels), both flattened.  The index is reset
        and refilled with the centroids first; train() must have run.
        """
        assert self.centroids is not None, "should train before assigning"
        self.index.reset()
        self.index.add(self.centroids)
        D, I = self.index.search(x, 1)
        return D.ravel(), I.ravel()
||
749 | |||
# IndexProxy was renamed to IndexReplicas, remap the old name for any old code
# people may have
IndexProxy = IndexReplicas
# ConcatenatedInvertedLists is a second name bound to HStackInvertedLists
ConcatenatedInvertedLists = HStackInvertedLists
||
754 | |||
755 | ########################################### |
||
756 | # serialization of indexes to byte arrays |
||
757 | ########################################### |
||
758 | |||
def serialize_index(index):
    """Write ``index`` with write_index into a VectorIOWriter and return
    the writer's data as a numpy array."""
    sink = VectorIOWriter()
    write_index(index, sink)
    buffer = sink.data
    return vector_to_array(buffer)
||
764 | |||
def deserialize_index(data):
    """Copy the array ``data`` into a VectorIOReader and return what
    read_index reads from it."""
    source = VectorIOReader()
    copy_array_to_vector(data, source.data)
    restored = read_index(source)
    return restored
||
769 | |||
def serialize_index_binary(index):
    """Write ``index`` with write_index_binary into a VectorIOWriter and
    return the writer's data as a numpy array."""
    sink = VectorIOWriter()
    write_index_binary(index, sink)
    buffer = sink.data
    return vector_to_array(buffer)
||
775 | |||
def deserialize_index_binary(data):
    """Copy the array ``data`` into a VectorIOReader and return what
    read_index_binary reads from it."""
    source = VectorIOReader()
    copy_array_to_vector(data, source.data)
    restored = read_index_binary(source)
    return restored
||
780 | |||
781 | |||
782 | ########################################### |
||
783 | # ResultHeap |
||
784 | ########################################### |
||
785 | |||
class ResultHeap:
    """Accumulate query results from a sliced dataset. The final result will
    be in self.D, self.I."""

    def __init__(self, nq, k):
        """nq: number of query vectors, k: number of results per query

        Allocates self.D (float32) and self.I (int64), both (nq, k), and a
        float_maxheap_array_t whose val / ids pointers refer to them.
        """
        self.I = np.zeros((nq, k), dtype='int64')
        self.D = np.zeros((nq, k), dtype='float32')
        self.nq, self.k = nq, k
        heaps = float_maxheap_array_t()
        heaps.k = k
        heaps.nh = nq
        heaps.val = swig_ptr(self.D)
        heaps.ids = swig_ptr(self.I)
        heaps.heapify()
        self.heaps = heaps

    def add_result(self, D, I):
        """D, I do not need to be in a particular order (heap or sorted)

        Both must have shape (nq, k).
        """
        assert D.shape == (self.nq, self.k)
        assert I.shape == (self.nq, self.k)
        # swig_ptr is a name of this module; there is no `faiss` name in
        # scope here, so going through `faiss.swig_ptr` raised NameError.
        self.heaps.addn_with_ids(
            self.k, swig_ptr(D),
            swig_ptr(I), self.k)

    def finalize(self):
        """Reorder the heaps in place; read self.D and self.I afterwards."""
        self.heaps.reorder()
||
813 |