| Total Complexity | 73 |
| Total Lines | 813 |
| Duplicated Lines | 9.72 % |
| Changes | 0 | ||
Duplicate code is one of the most pungent code smells. A rule that is often used is to re-structure code once it is duplicated in three or more places.
Common duplication problems and their corresponding solutions are:
Complex classes like faiss often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
| 1 | # Copyright (c) Facebook, Inc. and its affiliates. |
||
| 2 | # |
||
| 3 | # This source code is licensed under the MIT license found in the |
||
| 4 | # LICENSE file in the root directory of this source tree. |
||
| 5 | |||
| 6 | #@nolint |
||
| 7 | |||
| 8 | # not linting this file because it imports * from swigfaiss, which |
||
| 9 | # causes a ton of useless warnings. |
||
| 10 | |||
| 11 | import numpy as np |
||
| 12 | import sys |
||
| 13 | import inspect |
||
| 14 | import pdb |
||
| 15 | import platform |
||
| 16 | import subprocess |
||
| 17 | import logging |
||
| 18 | |||
| 19 | |||
| 20 | logger = logging.getLogger(__name__) |
||
| 21 | |||
| 22 | |||
| 23 | def instruction_set(): |
||
| 24 | if platform.system() == "Darwin": |
||
| 25 | if subprocess.check_output(["/usr/sbin/sysctl", "hw.optional.avx2_0"])[-1] == '1': |
||
| 26 | return "AVX2" |
||
| 27 | else: |
||
| 28 | return "default" |
||
| 29 | elif platform.system() == "Linux": |
||
| 30 | import numpy.distutils.cpuinfo |
||
| 31 | if "avx2" in numpy.distutils.cpuinfo.cpu.info[0].get('flags', ""): |
||
| 32 | return "AVX2" |
||
| 33 | else: |
||
| 34 | return "default" |
||
| 35 | |||
| 36 | |||
| 37 | try: |
||
| 38 | instr_set = instruction_set() |
||
| 39 | if instr_set == "AVX2": |
||
| 40 | logger.info("Loading faiss with AVX2 support.") |
||
| 41 | from .swigfaiss_avx2 import * |
||
| 42 | else: |
||
| 43 | logger.info("Loading faiss.") |
||
| 44 | from .swigfaiss import * |
||
| 45 | |||
| 46 | except ImportError: |
||
| 47 | # we import * so that the symbol X can be accessed as faiss.X |
||
| 48 | logger.info("Loading faiss.") |
||
| 49 | from .swigfaiss import * |
||
| 50 | |||
| 51 | |||
| 52 | __version__ = "%d.%d.%d" % (FAISS_VERSION_MAJOR, |
||
|
|
|||
| 53 | FAISS_VERSION_MINOR, |
||
| 54 | FAISS_VERSION_PATCH) |
||
| 55 | |||
| 56 | ################################################################## |
||
| 57 | # The functions below add or replace some methods for classes |
||
| 58 | # this is to be able to pass in numpy arrays directly |
||
| 59 | # The C++ version of the classnames will be suffixed with _c |
||
| 60 | ################################################################## |
||
| 61 | |||
| 62 | |||
def replace_method(the_class, name, replacement, ignore_missing=False):
    """Install `replacement` as the_class.<name> and keep the previous
    attribute reachable as the_class.<name>_c.

    If the attribute does not exist, AttributeError is raised unless
    ignore_missing is set, in which case nothing happens. Nothing
    happens either when the current attribute is already a function
    called 'replacement_<name>' (typically inherited from a parent class
    that was patched before).
    """
    if ignore_missing and not hasattr(the_class, name):
        return
    # raises AttributeError when the method is missing
    current = getattr(the_class, name)
    already_patched = current.__name__ == 'replacement_' + name
    if not already_patched:
        setattr(the_class, name + '_c', current)
        setattr(the_class, name, replacement)
||
| 75 | |||
| 76 | |||
def handle_Clustering():
    """Make Clustering.train and Clustering.train_encoded accept numpy
    arrays; the wrapped originals stay available with a _c suffix."""

    def replacement_train(self, x, index, weights=None):
        # x holds one training vector per row
        nb, dim = x.shape
        assert dim == self.d
        if weights is None:
            self.train_c(nb, swig_ptr(x), index)
            return
        assert weights.shape == (nb, )
        self.train_c(nb, swig_ptr(x), index, swig_ptr(weights))

    def replacement_train_encoded(self, x, codec, index, weights=None):
        # x holds one code per row, as produced by `codec`
        nb, width = x.shape
        assert width == codec.sa_code_size()
        assert codec.d == index.d
        if weights is None:
            self.train_encoded_c(nb, swig_ptr(x), codec, index)
            return
        assert weights.shape == (nb, )
        self.train_encoded_c(nb, swig_ptr(x), codec, index, swig_ptr(weights))

    replace_method(Clustering, 'train', replacement_train)
    replace_method(Clustering, 'train_encoded', replacement_train_encoded)


handle_Clustering()
||
| 100 | |||
| 101 | |||
def handle_Quantizer(the_class):
    """Replace train / compute_codes / decode of a quantizer class with
    versions that take and return numpy arrays."""

    def replacement_train(self, x):
        nb, dim = x.shape
        assert dim == self.d
        self.train_c(nb, swig_ptr(x))

    def replacement_compute_codes(self, x):
        nb, dim = x.shape
        assert dim == self.d
        # one code of self.code_size bytes per input row
        out = np.empty((nb, self.code_size), dtype='uint8')
        self.compute_codes_c(swig_ptr(x), swig_ptr(out), nb)
        return out

    def replacement_decode(self, codes):
        nb, width = codes.shape
        assert width == self.code_size
        out = np.empty((nb, self.d), dtype='float32')
        self.decode_c(swig_ptr(codes), swig_ptr(out), nb)
        return out

    for method_name, wrapper in (
            ('train', replacement_train),
            ('compute_codes', replacement_compute_codes),
            ('decode', replacement_decode)):
        replace_method(the_class, method_name, wrapper)


handle_Quantizer(ProductQuantizer)
handle_Quantizer(ScalarQuantizer)
||
| 130 | |||
| 131 | |||
def handle_Index(the_class):
    """Replace the pointer-based methods of an Index class with versions
    that take and return numpy arrays. The originals remain available
    with a _c suffix."""

    def replacement_add(self, x):
        assert x.flags.contiguous
        nb, dim = x.shape
        assert dim == self.d
        self.add_c(nb, swig_ptr(x))

    def replacement_add_with_ids(self, x, ids):
        nb, dim = x.shape
        assert dim == self.d
        assert ids.shape == (nb, ), 'not same nb of vectors as ids'
        self.add_with_ids_c(nb, swig_ptr(x), swig_ptr(ids))

    def replacement_assign(self, x, k):
        nb, dim = x.shape
        assert dim == self.d
        out = np.empty((nb, k), dtype=np.int64)
        self.assign_c(nb, swig_ptr(x), swig_ptr(out), k)
        return out

    def replacement_train(self, x):
        assert x.flags.contiguous
        nb, dim = x.shape
        assert dim == self.d
        self.train_c(nb, swig_ptr(x))

    def replacement_search(self, x, k):
        nb, dim = x.shape
        assert dim == self.d
        # output buffers filled in place by the C++ call
        dis = np.empty((nb, k), dtype=np.float32)
        ids = np.empty((nb, k), dtype=np.int64)
        self.search_c(nb, swig_ptr(x), k, swig_ptr(dis), swig_ptr(ids))
        return dis, ids

    def replacement_search_and_reconstruct(self, x, k):
        nb, dim = x.shape
        assert dim == self.d
        dis = np.empty((nb, k), dtype=np.float32)
        ids = np.empty((nb, k), dtype=np.int64)
        vecs = np.empty((nb, k, dim), dtype=np.float32)
        self.search_and_reconstruct_c(
            nb, swig_ptr(x), k,
            swig_ptr(dis), swig_ptr(ids), swig_ptr(vecs))
        return dis, ids, vecs

    def replacement_remove_ids(self, x):
        # an IDSelector is forwarded untouched
        if isinstance(x, IDSelector):
            return self.remove_ids_c(x)
        assert x.ndim == 1
        index_ivf = try_extract_index_ivf(self)
        # pick the selector flavour depending on the direct map type
        if index_ivf and index_ivf.direct_map.type == DirectMap.Hashtable:
            selector = IDSelectorArray(x.size, swig_ptr(x))
        else:
            selector = IDSelectorBatch(x.size, swig_ptr(x))
        return self.remove_ids_c(selector)

    def replacement_reconstruct(self, key):
        out = np.empty(self.d, dtype=np.float32)
        self.reconstruct_c(key, swig_ptr(out))
        return out

    def replacement_reconstruct_n(self, n0, ni):
        out = np.empty((ni, self.d), dtype=np.float32)
        self.reconstruct_n_c(n0, ni, swig_ptr(out))
        return out

    def replacement_update_vectors(self, keys, x):
        nb = keys.size
        assert keys.shape == (nb, )
        assert x.shape == (nb, self.d)
        self.update_vectors_c(nb, swig_ptr(keys), swig_ptr(x))

    def replacement_range_search(self, x, thresh):
        nb, dim = x.shape
        assert dim == self.d
        result = RangeSearchResult(nb)
        self.range_search_c(nb, swig_ptr(x), thresh, result)
        # the buffers belong to `result`: copy them before it goes away
        lims = rev_swig_ptr(result.lims, nb + 1).copy()
        total = int(lims[-1])
        dis = rev_swig_ptr(result.distances, total).copy()
        ids = rev_swig_ptr(result.labels, total).copy()
        return lims, dis, ids

    def replacement_sa_encode(self, x):
        nb, dim = x.shape
        assert dim == self.d
        out = np.empty((nb, self.sa_code_size()), dtype='uint8')
        self.sa_encode_c(nb, swig_ptr(x), swig_ptr(out))
        return out

    def replacement_sa_decode(self, codes):
        nb, width = codes.shape
        assert width == self.sa_code_size()
        out = np.empty((nb, self.d), dtype='float32')
        self.sa_decode_c(nb, swig_ptr(codes), swig_ptr(out))
        return out

    # (method name, wrapper, tolerate classes that lack the method)
    patches = (
        ('add', replacement_add, False),
        ('add_with_ids', replacement_add_with_ids, False),
        ('assign', replacement_assign, False),
        ('train', replacement_train, False),
        ('search', replacement_search, False),
        ('remove_ids', replacement_remove_ids, False),
        ('reconstruct', replacement_reconstruct, False),
        ('reconstruct_n', replacement_reconstruct_n, False),
        ('range_search', replacement_range_search, False),
        ('update_vectors', replacement_update_vectors, True),
        ('search_and_reconstruct', replacement_search_and_reconstruct, True),
        ('sa_encode', replacement_sa_encode, False),
        ('sa_decode', replacement_sa_decode, False),
    )
    for method_name, wrapper, optional in patches:
        replace_method(the_class, method_name, wrapper,
                       ignore_missing=optional)
||
| 250 | |||
def handle_IndexBinary(the_class):
    """Replace the pointer-based methods of an IndexBinary class with
    versions that take and return numpy arrays. Input rows are checked to
    have self.d / 8 columns."""

    def replacement_add(self, x):
        assert x.flags.contiguous
        nb, width = x.shape
        assert width * 8 == self.d
        self.add_c(nb, swig_ptr(x))

    def replacement_add_with_ids(self, x, ids):
        nb, width = x.shape
        assert width * 8 == self.d
        assert ids.shape == (nb, ), 'not same nb of vectors as ids'
        self.add_with_ids_c(nb, swig_ptr(x), swig_ptr(ids))

    def replacement_train(self, x):
        assert x.flags.contiguous
        nb, width = x.shape
        assert width * 8 == self.d
        self.train_c(nb, swig_ptr(x))

    def replacement_reconstruct(self, key):
        out = np.empty(self.d // 8, dtype=np.uint8)
        self.reconstruct_c(key, swig_ptr(out))
        return out

    def replacement_search(self, x, k):
        nb, width = x.shape
        assert width * 8 == self.d
        # output buffers filled in place by the C++ call
        dis = np.empty((nb, k), dtype=np.int32)
        ids = np.empty((nb, k), dtype=np.int64)
        self.search_c(nb, swig_ptr(x), k, swig_ptr(dis), swig_ptr(ids))
        return dis, ids

    def replacement_range_search(self, x, thresh):
        nb, width = x.shape
        assert width * 8 == self.d
        result = RangeSearchResult(nb)
        self.range_search_c(nb, swig_ptr(x), thresh, result)
        # the buffers belong to `result`: copy them before it goes away
        lims = rev_swig_ptr(result.lims, nb + 1).copy()
        total = int(lims[-1])
        dis = rev_swig_ptr(result.distances, total).copy()
        ids = rev_swig_ptr(result.labels, total).copy()
        return lims, dis, ids

    def replacement_remove_ids(self, x):
        # an IDSelector is forwarded untouched
        if isinstance(x, IDSelector):
            return self.remove_ids_c(x)
        assert x.ndim == 1
        selector = IDSelectorBatch(x.size, swig_ptr(x))
        return self.remove_ids_c(selector)

    for method_name, wrapper in (
            ('add', replacement_add),
            ('add_with_ids', replacement_add_with_ids),
            ('train', replacement_train),
            ('search', replacement_search),
            ('range_search', replacement_range_search),
            ('reconstruct', replacement_reconstruct),
            ('remove_ids', replacement_remove_ids)):
        replace_method(the_class, method_name, wrapper)
||
| 313 | |||
| 314 | |||
def handle_VectorTransform(the_class):
    """Add numpy-friendly train / apply_py / reverse_transform methods to
    a VectorTransform class."""

    def apply_method(self, x):
        assert x.flags.contiguous
        nb, dim = x.shape
        assert dim == self.d_in
        out = np.empty((nb, self.d_out), dtype=np.float32)
        self.apply_noalloc(nb, swig_ptr(x), swig_ptr(out))
        return out

    def replacement_reverse_transform(self, x):
        nb, dim = x.shape
        assert dim == self.d_out
        out = np.empty((nb, self.d_in), dtype=np.float32)
        self.reverse_transform_c(nb, swig_ptr(x), swig_ptr(out))
        return out

    def replacement_vt_train(self, x):
        assert x.flags.contiguous
        nb, dim = x.shape
        assert dim == self.d_in
        self.train_c(nb, swig_ptr(x))

    replace_method(the_class, 'train', replacement_vt_train)
    # apply is reserved in Python, so the wrapper is exposed as apply_py
    setattr(the_class, 'apply_py', apply_method)
    replace_method(the_class, 'reverse_transform',
                   replacement_reverse_transform)
||
| 343 | |||
| 344 | |||
def handle_AutoTuneCriterion(the_class):
    """Make set_groundtruth and evaluate of an AutoTuneCriterion class
    accept numpy arrays."""

    def replacement_set_groundtruth(self, D, I):
        """D: ground-truth distances or None, I: ground-truth labels,
        a 2D array whose shape sets self.nq and self.gt_nnn."""
        # `if D:` would raise ValueError on any array with more than one
        # element, so the presence of D is tested explicitly.
        has_distances = D is not None
        if has_distances:
            assert I.shape == D.shape
        self.nq, self.gt_nnn = I.shape
        self.set_groundtruth_c(
            self.gt_nnn, swig_ptr(D) if has_distances else None, swig_ptr(I))

    def replacement_evaluate(self, D, I):
        """D, I: search results, both of shape (self.nq, self.nnn)."""
        assert I.shape == D.shape
        assert I.shape == (self.nq, self.nnn)
        return self.evaluate_c(swig_ptr(D), swig_ptr(I))

    replace_method(the_class, 'set_groundtruth', replacement_set_groundtruth)
    replace_method(the_class, 'evaluate', replacement_evaluate)
||
| 360 | |||
| 361 | |||
def handle_ParameterSpace(the_class):
    """Make explore() of a ParameterSpace class take a numpy array of
    queries and return the OperatingPoints it filled."""

    def replacement_explore(self, index, xq, crit):
        # one query per row, as many rows as the criterion expects
        assert xq.shape == (crit.nq, index.d)
        operating_points = OperatingPoints()
        self.explore_c(index, crit.nq, swig_ptr(xq), crit, operating_points)
        return operating_points

    replace_method(the_class, 'explore', replacement_explore)
||
| 370 | |||
| 371 | |||
def handle_MatrixStats(the_class):
    """Let the constructor of the_class take a single 2D numpy array."""
    original_init = the_class.__init__

    def replacement_init(self, m):
        assert len(m.shape) == 2
        nb_rows, nb_cols = m.shape[0], m.shape[1]
        original_init(self, nb_rows, nb_cols, swig_ptr(m))

    setattr(the_class, '__init__', replacement_init)


handle_MatrixStats(MatrixStats)
||
| 382 | |||
| 383 | |||
this_module = sys.modules[__name__]


# Go through everything exported by this module and patch each class
# according to the base classes it derives from. A class can match
# several of the tests below.
for symbol in dir(this_module):
    obj = getattr(this_module, symbol)
    if not inspect.isclass(obj):
        continue
    the_class = obj

    if issubclass(the_class, Index):
        handle_Index(the_class)

    if issubclass(the_class, IndexBinary):
        handle_IndexBinary(the_class)

    if issubclass(the_class, VectorTransform):
        handle_VectorTransform(the_class)

    if issubclass(the_class, AutoTuneCriterion):
        handle_AutoTuneCriterion(the_class)

    if issubclass(the_class, ParameterSpace):
        handle_ParameterSpace(the_class)
||
| 406 | |||
| 407 | |||
| 408 | ########################################### |
||
| 409 | # Add Python references to objects |
||
| 410 | # we do this at the Python class wrapper level. |
||
| 411 | ########################################### |
||
| 412 | |||
def add_ref_in_constructor(the_class, parameter_no):
    """Make instances of the_class keep a Python reference to some of
    their constructor arguments, so that these arguments do not get
    deallocated before the instance.

    parameter_no is either the position of the argument to keep, or a
    dict that maps the number of constructor arguments to the list of
    positions to keep. The references are stored in
    self.referenced_objects. A call that does not supply the argument(s)
    (e.g. a default constructor) stores an empty list instead of failing
    after the object has been built.
    """
    original_init = the_class.__init__

    def replacement_init(self, *args):
        original_init(self, *args)
        # slicing tolerates calls with too few positional arguments
        self.referenced_objects = list(args[parameter_no:parameter_no + 1])

    def replacement_init_multiple(self, *args):
        original_init(self, *args)
        # argument counts that are not listed keep no reference
        pset = parameter_no.get(len(args), ())
        self.referenced_objects = [args[no] for no in pset]

    if isinstance(parameter_no, dict):
        # a list of parameters to keep, depending on the number of arguments
        the_class.__init__ = replacement_init_multiple
    else:
        the_class.__init__ = replacement_init
||
| 432 | |||
def add_ref_in_method(the_class, method_name, parameter_no):
    """Wrap the_class.<method_name> so that the argument at position
    parameter_no is recorded in self.referenced_objects before the
    original method runs."""
    wrapped = getattr(the_class, method_name)

    def replacement_method(self, *args):
        kept = args[parameter_no]
        if hasattr(self, 'referenced_objects'):
            self.referenced_objects.append(kept)
        else:
            # first reference recorded on this object
            self.referenced_objects = [kept]
        return wrapped(self, *args)

    setattr(the_class, method_name, replacement_method)
||
| 443 | |||
def add_ref_in_function(function_name, parameter_no):
    """Wrap the module-level function `function_name` so that the object
    it returns keeps a reference to the argument at position
    parameter_no. Assumes the function returns an object."""
    wrapped = getattr(this_module, function_name)

    def replacement_function(*args):
        result = wrapped(*args)
        result.referenced_objects = [args[parameter_no]]
        return result

    setattr(this_module, function_name, replacement_function)
||
| 453 | |||
# Classes that must keep a reference to their first constructor argument.
for _ref_class in (
        IndexIVFFlat, IndexIVFFlatDedup,
        IndexIVFPQ, IndexIVFPQR,
        Index2Layer, Level1Quantizer,
        IndexIVFScalarQuantizer,
        IndexIDMap, IndexIDMap2,
        IndexHNSW,
        IndexRefineFlat,
        IndexBinaryIVF, IndexBinaryFromFloat,
        IndexBinaryIDMap, IndexBinaryIDMap2,
        BufferedIOWriter, BufferedIOReader):
    add_ref_in_constructor(_ref_class, 0)
del _ref_class

# The arguments to keep depend on how many are passed.
add_ref_in_constructor(IndexPreTransform, {2: [0, 1], 1: [0]})

# Methods that must keep a reference to their first argument.
for _ref_class, _ref_method in (
        (IndexPreTransform, 'prepend_transform'),
        (IndexShards, 'add_shard'),
        (IndexBinaryShards, 'add_shard'),
        (IndexReplicas, 'addIndex'),
        (IndexBinaryReplicas, 'addIndex')):
    add_ref_in_method(_ref_class, _ref_method, 0)
del _ref_class, _ref_method

# seems really marginal...
# remove_ref_from_method(IndexReplicas, 'removeIndex', 0)

if hasattr(this_module, 'GpuIndexFlat'):
    # handle all the GPUResources refs
    add_ref_in_function('index_cpu_to_gpu', 0)
    for _ref_class in (
            GpuIndexFlat, GpuIndexFlatIP, GpuIndexFlatL2,
            GpuIndexIVFFlat, GpuIndexIVFScalarQuantizer,
            GpuIndexIVFPQ, GpuIndexBinaryFlat):
        add_ref_in_constructor(_ref_class, 0)
    del _ref_class
||
| 493 | |||
| 494 | |||
| 495 | |||
| 496 | ########################################### |
||
| 497 | # GPU functions |
||
| 498 | ########################################### |
||
| 499 | |||
| 500 | |||
def index_cpu_to_gpu_multiple_py(resources, index, co=None, gpus=None):
    """Build the C++ vectors of GPU resources and device numbers, then
    call index_cpu_to_gpu_multiple with them.

    When gpus is None, resource number i is paired with device i. The
    returned index keeps a reference to `resources`.
    """
    if gpus is None:
        gpus = range(len(resources))
    device_vec = IntVector()
    resource_vec = GpuResourcesVector()
    for device, resource in zip(gpus, resources):
        device_vec.push_back(device)
        resource_vec.push_back(resource)
    gpu_index = index_cpu_to_gpu_multiple(resource_vec, device_vec, index, co)
    # keep the resources alive as long as the index
    gpu_index.referenced_objects = resources
    return gpu_index
||
| 515 | |||
| 516 | |||
def index_cpu_to_all_gpus(index, co=None, ngpu=-1):
    """Shorthand for index_cpu_to_gpus_list without an explicit list of
    GPU ids."""
    return index_cpu_to_gpus_list(index, co=co, gpus=None, ngpu=ngpu)
||
| 520 | |||
| 521 | |||
def index_cpu_to_gpus_list(index, co=None, gpus=None, ngpu=-1):
    """Copy `index` to several GPUs.

    gpus must be a list of GPU ids or None. When it is None, the first
    ngpu GPUs are used, or all the GPUs reported by get_num_gpus() if
    ngpu is -1. One StandardGpuResources object is created per GPU.
    """
    if gpus is None:
        nb_devices = get_num_gpus() if ngpu == -1 else ngpu
        gpus = range(nb_devices)
    resources = [StandardGpuResources() for _ in gpus]
    return index_cpu_to_gpu_multiple_py(resources, index, co, gpus)
||
| 532 | |||
| 533 | |||
| 534 | ########################################### |
||
| 535 | # numpy array / std::vector conversions |
||
| 536 | ########################################### |
||
| 537 | |||
| 538 | # mapping from vector names in swigfaiss.swig and the numpy dtype names |
||
# prefix of the <Prefix>Vector class name -> numpy dtype name
vector_name_map = dict([
    ('Float', 'float32'),
    ('Byte', 'uint8'),
    ('Char', 'int8'),
    ('Uint64', 'uint64'),
    ('Long', 'int64'),
    ('Int', 'int32'),
    ('Double', 'float64'),
])
||
| 548 | |||
def vector_to_array(v):
    """Copy the contents of a C++ vector into a new numpy array.

    The dtype is looked up in vector_name_map from the class name of v,
    which must end with 'Vector'.
    """
    classname = type(v).__name__
    assert classname.endswith('Vector')
    prefix = classname[:-len('Vector')]
    nb = v.size()
    out = np.empty(nb, dtype=np.dtype(vector_name_map[prefix]))
    # an empty vector has nothing to copy
    if nb > 0:
        memcpy(swig_ptr(out), v.data(), out.nbytes)
    return out
||
| 558 | |||
| 559 | |||
def vector_float_to_array(v):
    """Same as vector_to_array; kept as a separate name."""
    converted = vector_to_array(v)
    return converted
||
| 562 | |||
| 563 | |||
def copy_array_to_vector(a, v):
    """Resize the C++ vector v and fill it with the contents of the 1D
    numpy array a. The dtype of a must be the one associated with the
    class of v in vector_name_map."""
    (nb,) = a.shape
    classname = type(v).__name__
    assert classname.endswith('Vector')
    expected = np.dtype(vector_name_map[classname[:-len('Vector')]])
    assert expected == a.dtype, (
        'cannot copy a %s array to a %s (should be %s)' % (
            a.dtype, classname, expected))
    v.resize(nb)
    # nothing to copy for an empty array
    if nb > 0:
        memcpy(v.data(), swig_ptr(a), a.nbytes)
||
| 576 | |||
| 577 | |||
| 578 | ########################################### |
||
| 579 | # Wrapper for a few functions |
||
| 580 | ########################################### |
||
| 581 | |||
def kmin(array, k):
    """Return the k smallest values (and their indices) of each line of
    a float32 array, as a pair of (m, k) arrays."""
    nb_lines, line_len = array.shape
    vals = np.zeros((nb_lines, k), dtype='float32')
    ids = np.zeros((nb_lines, k), dtype='int64')
    # one heap of size k per line, stored in vals / ids
    heaps = float_maxheap_array_t()
    heaps.nh = nb_lines
    heaps.k = k
    heaps.val = swig_ptr(vals)
    heaps.ids = swig_ptr(ids)
    heaps.heapify()
    heaps.addn(line_len, swig_ptr(array))
    heaps.reorder()
    return vals, ids
||
| 597 | |||
| 598 | |||
def kmax(array, k):
    """Return the k largest values (and their indices) of each line of a
    float32 array, as a pair of (m, k) arrays."""
    nb_lines, line_len = array.shape
    vals = np.zeros((nb_lines, k), dtype='float32')
    ids = np.zeros((nb_lines, k), dtype='int64')
    # one heap of size k per line, stored in vals / ids
    heaps = float_minheap_array_t()
    heaps.nh = nb_lines
    heaps.k = k
    heaps.val = swig_ptr(vals)
    heaps.ids = swig_ptr(ids)
    heaps.heapify()
    heaps.addn(line_len, swig_ptr(array))
    heaps.reorder()
    return vals, ids
||
| 614 | |||
| 615 | |||
def pairwise_distances(xq, xb, mt=METRIC_L2, metric_arg=0):
    """Compute the whole (nq, nb) distance matrix between two sets of
    vectors. METRIC_L2 goes through pairwise_L2sqr, any other metric
    through pairwise_extra_distances."""
    nq, d = xq.shape
    nb, d_base = xb.shape
    assert d == d_base
    out = np.empty((nq, nb), dtype='float32')
    if mt != METRIC_L2:
        pairwise_extra_distances(
            d, nq, swig_ptr(xq), nb, swig_ptr(xb),
            mt, metric_arg, swig_ptr(out))
    else:
        pairwise_L2sqr(
            d, nq, swig_ptr(xq), nb, swig_ptr(xb), swig_ptr(out))
    return out
||
| 635 | |||
| 636 | |||
| 637 | |||
| 638 | |||
def rand(n, seed=12345):
    """Return a float32 array of shape n filled by float_rand with the
    given seed."""
    out = np.empty(n, dtype='float32')
    float_rand(swig_ptr(out), out.size, seed)
    return out
||
| 643 | |||
| 644 | |||
def randint(n, seed=12345, vmax=None):
    """Return an int64 array of shape n filled by int64_rand, or by
    int64_rand_max when vmax is given."""
    out = np.empty(n, dtype='int64')
    if vmax is not None:
        int64_rand_max(swig_ptr(out), out.size, vmax, seed)
    else:
        int64_rand(swig_ptr(out), out.size, seed)
    return out

# second name for the same function
lrand = randint
||
| 654 | |||
def randn(n, seed=12345):
    """Return a float32 array of shape n filled by float_randn with the
    given seed."""
    out = np.empty(n, dtype='float32')
    float_randn(swig_ptr(out), out.size, seed)
    return out
||
| 659 | |||
| 660 | |||
def eval_intersection(I1, I2):
    """Sum, over the lines of two result tables, of the size of the
    intersection between line i of I1 and line i of I2."""
    nb_lines = I1.shape[0]
    assert I2.shape[0] == nb_lines
    width1 = I1.shape[1]
    width2 = I2.shape[1]
    return sum(
        ranklist_intersection_size(
            width1, swig_ptr(I1[row]), width2, swig_ptr(I2[row]))
        for row in range(nb_lines))
||
| 671 | |||
| 672 | |||
def normalize_L2(x):
    """Call fvec_renorm_L2 on the rows of the 2D array x; x is passed by
    pointer and nothing is returned."""
    nb_rows, dim = x.shape[0], x.shape[1]
    fvec_renorm_L2(dim, nb_rows, swig_ptr(x))
||
| 675 | |||
| 676 | # MapLong2Long interface |
||
| 677 | |||
def replacement_map_add(self, keys, vals):
    """Add the (keys[i], vals[i]) pairs to the map. keys and vals are 1D
    arrays of the same length."""
    n, = keys.shape
    # The previous check compared keys.shape with itself; vals is the
    # array that has to be validated before its pointer is handed over.
    assert vals.shape == (n,), 'not same nb of keys as vals'
    self.add_c(n, swig_ptr(keys), swig_ptr(vals))
||
| 682 | |||
def replacement_map_search_multiple(self, keys):
    """Look up all the keys of a 1D array at once and return the values
    as an int64 array of the same length."""
    (nb,) = keys.shape
    found = np.empty(nb, dtype='int64')
    self.search_multiple_c(nb, swig_ptr(keys), swig_ptr(found))
    return found

replace_method(MapLong2Long, 'add', replacement_map_add)
replace_method(MapLong2Long, 'search_multiple', replacement_map_search_multiple)
||
| 691 | |||
| 692 | |||
| 693 | ########################################### |
||
| 694 | # Kmeans object |
||
| 695 | ########################################### |
||
| 696 | |||
| 697 | |||
class Kmeans:
    """shallow wrapper around the Clustering object. The important method
    is train()."""

    def __init__(self, d, k, **kwargs):
        """d: input dimension, k: nb of centroids. Additional
        parameters are passed on the ClusteringParameters object,
        including niter=25, verbose=False, spherical = False.
        The special parameter gpu (default False) selects GPU training:
        True uses all GPUs, an integer uses that many GPUs.
        """
        self.d = d
        self.k = k
        self.gpu = False
        self.cp = ClusteringParameters()
        # distinct loop names, so that the parameter k is not shadowed
        for key, value in kwargs.items():
            if key == 'gpu':
                self.gpu = value
            else:
                # if this raises an exception, it means that it is a non-existent field
                getattr(self.cp, key)
                setattr(self.cp, key, value)
        self.centroids = None

    def train(self, x, weights=None):
        """Cluster the rows of x (one vector of dimension self.d per
        row), optionally weighted.

        Sets self.index, self.centroids (shape (k, d)) and self.obj (the
        objective recorded at each iteration). Returns the last value of
        self.obj, or 0.0 if no iteration was recorded.
        """
        n, d = x.shape
        assert d == self.d
        clus = Clustering(d, self.k, self.cp)
        if self.cp.spherical:
            self.index = IndexFlatIP(d)
        else:
            self.index = IndexFlatL2(d)
        if self.gpu:
            # `self.gpu == True` also holds for gpu=1, which made a
            # request for one GPU use all of them; only a real boolean
            # selects all GPUs.
            if isinstance(self.gpu, (bool, np.bool_)):
                ngpu = -1
            else:
                ngpu = self.gpu
            self.index = index_cpu_to_all_gpus(self.index, ngpu=ngpu)
        clus.train(x, self.index, weights)
        centroids = vector_float_to_array(clus.centroids)
        self.centroids = centroids.reshape(self.k, d)
        stats = clus.iteration_stats
        self.obj = np.array([
            stats.at(i).obj for i in range(stats.size())
        ])
        return self.obj[-1] if self.obj.size > 0 else 0.0

    def assign(self, x):
        """Search the nearest centroid of each row of x. Returns the
        flattened distances and centroid labels. train() must have been
        called before."""
        assert self.centroids is not None, "should train before assigning"
        # rebuild the index from the stored centroids
        self.index.reset()
        self.index.add(self.centroids)
        D, I = self.index.search(x, 1)
        return D.ravel(), I.ravel()
||
| 749 | |||
| 750 | # IndexProxy was renamed to IndexReplicas, remap the old name for any old code |
||
| 751 | # people may have |
||
| 752 | IndexProxy = IndexReplicas |
||
| 753 | ConcatenatedInvertedLists = HStackInvertedLists |
||
| 754 | |||
| 755 | ########################################### |
||
| 756 | # serialization of indexes to byte arrays |
||
| 757 | ########################################### |
||
| 758 | |||
def serialize_index(index):
    """Write an index to memory with write_index and return the bytes as
    a numpy uint8 array."""
    buf = VectorIOWriter()
    write_index(index, buf)
    serialized = vector_to_array(buf.data)
    return serialized
||
| 764 | |||
def deserialize_index(data):
    """Inverse of serialize_index: load the numpy array `data` into a
    VectorIOReader and return what read_index reads from it."""
    src = VectorIOReader()
    copy_array_to_vector(data, src.data)
    index = read_index(src)
    return index
||
| 769 | |||
def serialize_index_binary(index):
    """Write a binary index to memory with write_index_binary and return
    the bytes as a numpy uint8 array."""
    buf = VectorIOWriter()
    write_index_binary(index, buf)
    serialized = vector_to_array(buf.data)
    return serialized
||
| 775 | |||
def deserialize_index_binary(data):
    """Inverse of serialize_index_binary: load the numpy array `data`
    into a VectorIOReader and return what read_index_binary reads from
    it."""
    src = VectorIOReader()
    copy_array_to_vector(data, src.data)
    index = read_index_binary(src)
    return index
||
| 780 | |||
| 781 | |||
| 782 | ########################################### |
||
| 783 | # ResultHeap |
||
| 784 | ########################################### |
||
| 785 | |||
class ResultHeap:
    """Accumulate query results from a sliced dataset. The final result will
    be in self.D, self.I."""

    def __init__(self, nq, k):
        " nq: number of query vectors, k: number of results per query "
        self.I = np.zeros((nq, k), dtype='int64')
        self.D = np.zeros((nq, k), dtype='float32')
        self.nq, self.k = nq, k
        # the heaps work directly on the memory of self.D and self.I
        heaps = float_maxheap_array_t()
        heaps.k = k
        heaps.nh = nq
        heaps.val = swig_ptr(self.D)
        heaps.ids = swig_ptr(self.I)
        heaps.heapify()
        self.heaps = heaps

    def add_result(self, D, I):
        """Merge a (nq, k) table of distances D and labels I into the
        heaps. D, I do not need to be in a particular order (heap or
        sorted)."""
        assert D.shape == (self.nq, self.k)
        assert I.shape == (self.nq, self.k)
        # swig_ptr is a module-level name here; the name `faiss` is not
        # bound inside the package itself, so `faiss.swig_ptr` raised
        # NameError.
        self.heaps.addn_with_ids(
            self.k, swig_ptr(D),
            swig_ptr(I), self.k)

    def finalize(self):
        """Reorder the heaps; to be called once all results are added."""
        self.heaps.reorder()
||
| 813 |