Passed
Push — master ( 596409...1042f2 )
by
unknown
02:06
created

utils.gen_invalid_term()   A

Complexity

Conditions 1

Size

Total Lines 7
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 6
nop 0
dl 0
loc 7
rs 10
c 0
b 0
f 0
1
import os
2
import sys
3
import random
4
import pdb
5
import string
6
import struct
7
import logging
8
import time, datetime
9
import copy
10
import numpy as np
11
from sklearn import preprocessing
12
from milvus import Milvus, DataType
13
14
port = 19530
15
epsilon = 0.000001
16
default_flush_interval = 1
17
big_flush_interval = 1000
18
dimension = 128
19
nb = 6000
20
top_k = 10
21
segment_row_count = 5000
22
default_float_vec_field_name = "float_vector"
23
default_binary_vec_field_name = "binary_vector"
24
25
# TODO:
# Index types exercised by the tests.
# NOTE: order must stay aligned with default_index_params below (parallel lists).
all_index_types = [
    "FLAT",
    "IVF_FLAT",
    "IVF_SQ8",
    "IVF_SQ8_HYBRID",
    "IVF_PQ",
    "HNSW",
    # "NSG",
    "ANNOY",
    "BIN_FLAT",
    "BIN_IVF_FLAT"
]
38
39
# Default build parameters, positionally aligned with all_index_types above.
default_index_params = [
    {"nlist": 1024},  # FLAT
    {"nlist": 1024},  # IVF_FLAT
    {"nlist": 1024},  # IVF_SQ8
    {"nlist": 1024},  # IVF_SQ8_HYBRID
    {"nlist": 1024, "m": 16},  # IVF_PQ
    {"M": 48, "efConstruction": 500},  # HNSW
    # {"search_length": 50, "out_degree": 40, "candidate_pool_size": 100, "knng": 50},
    {"n_trees": 4},  # ANNOY
    {"nlist": 1024},  # BIN_FLAT
    {"nlist": 1024}  # BIN_IVF_FLAT
]
51
52
53
def index_cpu_not_support():
    """Index types that are not supported when running on CPU only."""
    return ["IVF_SQ8_HYBRID"]
55
56
57
def binary_support():
    """Index types that operate on binary vectors."""
    return ["BIN_FLAT", "BIN_IVF_FLAT"]
59
60
61
def delete_support():
    """Index types that support entity deletion."""
    return ["FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_SQ8_HYBRID", "IVF_PQ"]
63
64
65
def ivf():
    """IVF-family index types (accept an `nprobe` search parameter)."""
    return ["FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_SQ8_HYBRID", "IVF_PQ"]
67
68
69
def l2(x, y):
    """Euclidean (L2) distance between vectors x and y."""
    diff = np.array(x) - np.array(y)
    return np.linalg.norm(diff)
71
72
73
def ip(x, y):
    """Inner-product similarity between vectors x and y."""
    a = np.array(x)
    b = np.array(y)
    return np.inner(a, b)
75
76
77
def jaccard(x, y):
    """Jaccard distance between two 0/1 vectors: 1 - |x & y| / |x | y|."""
    # np.bool was removed in NumPy 1.24; the builtin bool dtype is equivalent.
    x = np.asarray(x, bool)
    y = np.asarray(y, bool)
    return 1 - np.double(np.bitwise_and(x, y).sum()) / np.double(np.bitwise_or(x, y).sum())
81
82
83
def hamming(x, y):
    """Hamming distance: number of positions where the 0/1 vectors differ."""
    # np.bool was removed in NumPy 1.24; the builtin bool dtype is equivalent.
    x = np.asarray(x, bool)
    y = np.asarray(y, bool)
    return np.bitwise_xor(x, y).sum()
87
88
89
def tanimoto(x, y):
    """Tanimoto distance: -log2 of the Jaccard similarity of two 0/1 vectors."""
    # np.bool was removed in NumPy 1.24; the builtin bool dtype is equivalent.
    x = np.asarray(x, bool)
    y = np.asarray(y, bool)
    return -np.log2(np.double(np.bitwise_and(x, y).sum()) / np.double(np.bitwise_or(x, y).sum()))
93
94
95
def substructure(x, y):
    """Substructure distance: 1 - |x & y| / |y| (0 when x's bits cover y)."""
    # np.bool was removed in NumPy 1.24; the builtin bool dtype is equivalent.
    x = np.asarray(x, bool)
    y = np.asarray(y, bool)
    return 1 - np.double(np.bitwise_and(x, y).sum()) / np.count_nonzero(y)
99
100
101
def superstructure(x, y):
    """Superstructure distance: 1 - |x & y| / |x| (0 when y's bits cover x)."""
    # np.bool was removed in NumPy 1.24; the builtin bool dtype is equivalent.
    x = np.asarray(x, bool)
    y = np.asarray(y, bool)
    return 1 - np.double(np.bitwise_and(x, y).sum()) / np.count_nonzero(x)
105
106
107
def get_milvus(host, port, uri=None, handler=None, **kwargs):
    """Construct a Milvus client.

    Connects via `uri` when given, otherwise via host/port. `handler` defaults
    to "GRPC"; `try_connect` may be overridden through kwargs (default True).
    """
    if handler is None:
        handler = "GRPC"
    try_connect = kwargs.get("try_connect", True)
    if uri is None:
        return Milvus(host=host, port=port, handler=handler, try_connect=try_connect)
    return Milvus(uri=uri, handler=handler, try_connect=try_connect)
116
117
118
def disable_flush(connect):
    """Effectively disable auto-flush by setting a very large flush interval."""
    connect.set_config("storage", "auto_flush_interval", big_flush_interval)
120
121
122
def enable_flush(connect):
    """Restore the default auto-flush interval and verify the server took it."""
    # reset auto_flush_interval=1
    connect.set_config("storage", "auto_flush_interval", default_flush_interval)
    config_value = connect.get_config("storage", "auto_flush_interval")
    # the server reports config values as strings
    assert config_value == str(default_flush_interval)
127
128
129
def gen_inaccuracy(num):
    """Scale `num` down to a small perturbation (num / 255.0) for accuracy tests."""
    return num / 255.0
131
132
133
def gen_vectors(num, dim, is_normal=True):
    """Generate `num` random float vectors of length `dim`, L2-normalized.

    NOTE(review): `is_normal` is accepted but ignored — the output is always
    normalized; confirm whether any caller expects raw (unnormalized) vectors.
    """
    raw = [[random.random() for _ in range(dim)] for _ in range(num)]
    normalized = preprocessing.normalize(raw, axis=1, norm='l2')
    return normalized.tolist()
137
138
139
# def gen_vectors(num, dim, seed=np.random.RandomState(1234), is_normal=False):
140
#     xb = seed.rand(num, dim).astype("float32")
141
#     xb = preprocessing.normalize(xb, axis=1, norm='l2')
142
#     return xb.tolist()
143
144
145
def gen_binary_vectors(num, dim):
    """Generate `num` random bit vectors of `dim` bits.

    Returns (raw_vectors, binary_vectors): the 0/1 lists and the same bits
    packed into bytes (8 bits per byte).
    """
    raw_vectors = []
    binary_vectors = []
    for _ in range(num):
        bits = [random.randint(0, 1) for _ in range(dim)]
        raw_vectors.append(bits)
        binary_vectors.append(bytes(np.packbits(bits, axis=-1).tolist()))
    return raw_vectors, binary_vectors
153
154
155
def gen_binary_sub_vectors(vectors, length):
    """Copy the set bits of the first `length` input vectors.

    Returns (raw_vectors, binary_vectors): 0/1 lists carrying exactly the 1-bits
    of each source vector, plus the byte-packed form.
    """
    raw_vectors = []
    binary_vectors = []
    dim = len(vectors[0])
    for idx in range(length):
        bits = [0] * dim
        for pos, bit in enumerate(vectors[idx]):
            if bit == 1:
                bits[pos] = 1
        raw_vectors.append(bits)
        binary_vectors.append(bytes(np.packbits(bits, axis=-1).tolist()))
    return raw_vectors, binary_vectors
168
169
170
def gen_binary_super_vectors(vectors, length):
    """Generate `length` all-ones vectors (supersets of any input vector).

    Returns (raw_vectors, binary_vectors): the 0/1 lists and the byte-packed form.
    The unused `cnt_1 = np.count_nonzero(...)` local from the original was removed.
    """
    raw_vectors = []
    binary_vectors = []
    dim = len(vectors[0])
    for _ in range(length):
        # an all-ones vector is a superstructure of every possible input
        ones = [1] * dim
        raw_vectors.append(ones)
        binary_vectors.append(bytes(np.packbits(ones, axis=-1).tolist()))
    return raw_vectors, binary_vectors
180
181
182
def gen_int_attr(row_num):
    """Generate `row_num` random integer attribute values in [0, 255]."""
    values = []
    for _ in range(row_num):
        values.append(random.randint(0, 255))
    return values
184
185
186
def gen_float_attr(row_num):
    """Generate `row_num` random float attribute values in [0, 255]."""
    values = []
    for _ in range(row_num):
        values.append(random.uniform(0, 255))
    return values
188
189
190
def gen_unique_str(str_value=None):
    """Return `str_value` (or "test") plus "_" and a random 8-char suffix."""
    alphabet = string.ascii_letters + string.digits
    suffix = "".join(random.choice(alphabet) for _ in range(8))
    if str_value is None:
        return "test_" + suffix
    return str_value + "_" + suffix
193
194
195
def gen_single_filter_fields():
    """One field dict per filterable scalar DataType (int32/64, float, double)."""
    scalar_types = (DataType.INT32, DataType.INT64, DataType.FLOAT, DataType.DOUBLE)
    return [
        {"field": dt.name, "type": dt}
        for dt in DataType
        if dt in scalar_types
    ]
201
202
203
def gen_single_vector_fields():
    """One field dict per vector DataType, using the module default dimension."""
    return [
        {"field": dt.name, "type": dt, "params": {"dim": dimension}}
        for dt in (DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR)
    ]
209
210
211
def gen_default_fields(auto_id=False):
    """Default collection schema: int64 + float scalars and a float vector field.

    BUG FIX: the schema previously hard-coded "auto_id": True, which made the
    auto_id=False parameter a no-op and was inconsistent with
    gen_binary_default_fields. Now "auto_id" is only set when requested.
    """
    default_fields = {
        "fields": [
            {"field": "int64", "type": DataType.INT64},
            {"field": "float", "type": DataType.FLOAT},
            {"field": default_float_vec_field_name, "type": DataType.FLOAT_VECTOR, "params": {"dim": dimension}},
        ],
        "segment_row_count": segment_row_count
    }
    if auto_id is True:
        default_fields["auto_id"] = True
    return default_fields
224
225
226
def gen_binary_default_fields(auto_id=False):
    """Default collection schema whose vector field is a BINARY_VECTOR."""
    schema = {
        "fields": [
            {"field": "int64", "type": DataType.INT64},
            {"field": "float", "type": DataType.FLOAT},
            {"field": default_binary_vec_field_name, "type": DataType.BINARY_VECTOR, "params": {"dim": dimension}}
        ],
        "segment_row_count": segment_row_count
    }
    if auto_id is True:
        schema["auto_id"] = True
    return schema
238
239
240
def gen_entities(nb, is_normal=False):
    """Generate `nb` rows: sequential int64 ids, float values, random vectors."""
    return [
        {"field": "int64", "type": DataType.INT64, "values": list(range(nb))},
        {"field": "float", "type": DataType.FLOAT, "values": [float(i) for i in range(nb)]},
        {"field": default_float_vec_field_name, "type": DataType.FLOAT_VECTOR,
         "values": gen_vectors(nb, dimension, is_normal)},
    ]
248
249
250
def gen_binary_entities(nb):
    """Generate `nb` rows with a binary vector field.

    Returns (raw_vectors, entities) where raw_vectors are the unpacked 0/1 lists.
    """
    raw_vectors, packed = gen_binary_vectors(nb, dimension)
    entities = [
        {"field": "int64", "type": DataType.INT64, "values": list(range(nb))},
        {"field": "float", "type": DataType.FLOAT, "values": [float(i) for i in range(nb)]},
        {"field": default_binary_vec_field_name, "type": DataType.BINARY_VECTOR, "values": packed},
    ]
    return raw_vectors, entities
258
259
260
def gen_entities_by_fields(fields, nb, dimension):
    """Attach `nb` generated values to each field dict (mutates and returns them).

    BUG FIX: previously an unsupported field type fell through all branches and
    raised UnboundLocalError on `field_value`; now it raises a clear ValueError.
    """
    entities = []
    for field in fields:
        if field["type"] in [DataType.INT32, DataType.INT64]:
            field_value = [1 for _ in range(nb)]
        elif field["type"] in [DataType.FLOAT, DataType.DOUBLE]:
            field_value = [3.0 for _ in range(nb)]
        elif field["type"] == DataType.BINARY_VECTOR:
            field_value = gen_binary_vectors(nb, dimension)[1]
        elif field["type"] == DataType.FLOAT_VECTOR:
            field_value = gen_vectors(nb, dimension)
        else:
            raise ValueError("Unsupported field type: %s" % field["type"])
        field.update({"values": field_value})
        entities.append(field)
    return entities
274
275
276
def assert_equal_entity(a, b):
    """Placeholder: entity equality assertion is not implemented yet."""
    pass
278
279
280
def gen_query_vectors(field_name, entities, top_k, nq, search_params=None, rand_vector=False,
                      metric_type=None):
    """Build a vector search query against `field_name`.

    Query vectors are the first `nq` values of the last entity field, or freshly
    generated random vectors when rand_vector is True. Returns (query dict,
    query vectors).

    BUG FIX: `search_params` was a mutable default argument ({"nprobe": 10});
    it now defaults to None and is materialized per call.
    """
    if search_params is None:
        search_params = {"nprobe": 10}
    if rand_vector is True:
        dim = len(entities[-1]["values"][0])
        query_vectors = gen_vectors(nq, dim)
    else:
        query_vectors = entities[-1]["values"][:nq]
    must_param = {"vector": {field_name: {"topk": top_k, "query": query_vectors, "params": search_params}}}
    if metric_type is not None:
        must_param["vector"][field_name]["metric_type"] = metric_type
    query = {
        "bool": {
            "must": [must_param]
        }
    }
    return query, query_vectors
296
297
298
def update_query_expr(src_query, keep_old=True, expr=None):
    """Deep-copy src_query, merge `expr` into its "bool" clause, and drop the
    "must" clause unless keep_old is True. The source query is left untouched."""
    result = copy.deepcopy(src_query)
    if expr is not None:
        result["bool"].update(expr)
    if keep_old is not True:
        result["bool"].pop("must")
    return result
305
306
307
def gen_default_vector_expr(default_query):
    """Extract the first "must" clause (the vector expression) from a query."""
    must_clauses = default_query["bool"]["must"]
    return must_clauses[0]
309
310
311
def gen_default_term_expr(keyword="term", values=None):
    """Term expression on the "int64" field; defaults to the first nb//2 ints."""
    if values is None:
        values = list(range(nb // 2))
    return {keyword: {"int64": {"values": values}}}
316
317
318
def gen_default_range_expr(keyword="range", ranges=None):
    """Range expression on the "int64" field; defaults to GT 1 / LT nb//2."""
    if ranges is None:
        ranges = {"GT": 1, "LT": nb // 2}
    return {keyword: {"int64": {"ranges": ranges}}}
323
324
325
def gen_invalid_range():
    """Malformed "range" expressions used by negative tests.

    The local list was renamed from `range` to avoid shadowing the builtin.
    """
    invalid_ranges = [
        {"range": 1},
        {"range": {}},
        {"range": []},
        {"range": {"range": {"int64": {"ranges": {"GT": 0, "LT": nb//2}}}}}
    ]
    return invalid_ranges
333
334
335
def gen_invalid_ranges():
    """Range bodies that should be rejected: inverted, one-sided, float bounds."""
    return [
        {"GT": nb, "LT": 0},
        {"GT": nb},
        {"LT": 0},
        {"GT": 0.0, "LT": float(nb)},
    ]
343
344
345
def gen_valid_ranges():
    """Range bodies that should be accepted by the server."""
    return [
        {"GT": 0, "LT": nb // 2},
        {"GT": nb, "LT": nb * 2},
        {"GT": 0},
        {"LT": nb},
        {"GT": -1, "LT": top_k},
    ]
354
355
356
def gen_invalid_term():
    """Malformed "term" expressions used by negative tests."""
    return [
        {"term": 1},
        {"term": []},
        {"term": {"term": {"int64": {"values": [i for i in range(nb//2)]}}}},
    ]
363
364
365
def add_field_default(default_fields, type=DataType.INT64, field_name=None):
    """Return a deep copy of `default_fields` with one extra scalar field appended.

    NOTE: the `type` parameter shadows the builtin, but the name is part of the
    public signature and is kept for caller compatibility.
    """
    tmp_fields = copy.deepcopy(default_fields)
    if field_name is None:
        field_name = gen_unique_str()
    tmp_fields["fields"].append({
        "field": field_name,
        "type": type,
    })
    return tmp_fields
375
376
377
def add_field(entities, field_name=None):
    """Return a deep copy of `entities` plus an INT64 field of matching length."""
    row_count = len(entities[0]["values"])
    tmp_entities = copy.deepcopy(entities)
    if field_name is None:
        field_name = gen_unique_str()
    tmp_entities.append({
        "field": field_name,
        "type": DataType.INT64,
        "values": list(range(row_count)),
    })
    return tmp_entities
389
390
391
def add_vector_field(entities, is_normal=False):
    # NOTE(review): this definition is shadowed by a later
    # `add_vector_field(nb, dimension=...)` further down in this module, so this
    # version is unreachable once the module finishes importing — confirm which
    # one is intended and rename the other.
    # Appends a new random FLOAT_VECTOR field (same row count) to `entities` in place.
    nb = len(entities[0]["values"])
    vectors = gen_vectors(nb, dimension, is_normal)
    field = {
        "field": gen_unique_str(),
        "type": DataType.FLOAT_VECTOR,
        "values": vectors
    }
    entities.append(field)
    return entities
401
402
403
# def update_fields_metric_type(fields, metric_type):
404
#     tmp_fields = copy.deepcopy(fields)
405
#     if metric_type in ["L2", "IP"]:
406
#         tmp_fields["fields"][-1]["type"] = DataType.FLOAT_VECTOR
407
#     else:
408
#         tmp_fields["fields"][-1]["type"] = DataType.BINARY_VECTOR
409
#     tmp_fields["fields"][-1]["params"]["metric_type"] = metric_type
410
#     return tmp_fields
411
412
413
def remove_field(entities):
    """Drop the first (scalar) field from `entities` in place and return it."""
    entities.pop(0)
    return entities
416
417
418
def remove_vector_field(entities):
    """Drop the last (vector) field from `entities` in place and return it."""
    entities.pop()
    return entities
421
422
423
def update_field_name(entities, old_name, new_name):
    """Rename every field called `old_name` to `new_name`, in place."""
    for entry in entities:
        if entry["field"] == old_name:
            entry["field"] = new_name
    return entities
428
429
430
def update_field_type(entities, old_name, new_name):
    """Set the type of every field called `old_name` to `new_name`, in place."""
    for entry in entities:
        if entry["field"] == old_name:
            entry["type"] = new_name
    return entities
435
436
437
def update_field_value(entities, old_type, new_value):
    """Replace every value of fields whose type equals `old_type`, in place.

    BUG FIX: the original iterated the values themselves as list indices
    (`for i in item["values"]: item["values"][i] = ...`), which only worked by
    accident when the values happened to be valid indices. Now all positions
    are overwritten explicitly.
    """
    for item in entities:
        if item["type"] == old_type:
            for i in range(len(item["values"])):
                item["values"][i] = new_value
    return entities
443
444
445
def add_vector_field(nb, dimension=dimension):
    # NOTE(review): this definition shadows the earlier
    # `add_vector_field(entities, is_normal)` above — only this one survives import.
    # NOTE(review): `field` is built but never used; only `field_name` is
    # returned. Presumably the field dict was meant to be returned or appended
    # somewhere — confirm the intended behavior with callers.
    field_name = gen_unique_str()
    field = {
        "field": field_name,
        "type": DataType.FLOAT_VECTOR,
        "values": gen_vectors(nb, dimension)
    }
    return field_name
453
454
455
def gen_segment_row_counts():
    """Segment row-count values exercised by the tests (small to large)."""
    return [1, 2, 1024, 4096]
463
464
465
def gen_invalid_ips():
    """Host strings that should be rejected as IP addresses."""
    return [
        "127.0.0",
        "12-s",
        " ",
        "12 s",
        "BB。A",
        " siede ",
        "(mn)",
        "中文",
        "a".join("a" for _ in range(256)),
    ]
483
484
485
def gen_invalid_uris():
    """Connection URIs that should be rejected (blank/garbled host parts).

    The unused local `ip = None` from the original was removed; it was only
    referenced by commented-out examples.
    """
    uris = [
        " ",
        "中文",
        # invalid ip
        "tcp:// :19530",
        "tcp://127.0.0:19530",
        "tcp://\n:19530",
    ]
    return uris
512
513
514
def gen_invalid_strs():
    """Values that should be rejected wherever a well-formed string is required."""
    return [
        1,
        [1],
        None,
        "12-s",
        " ",
        # "",
        # None,
        "12 s",
        "BB。A",
        "c|c",
        " siede ",
        "(mn)",
        "pip+",
        "=c",
        "中文",
        "a".join("a" for _ in range(256)),
    ]
534
535
536
def gen_invalid_field_types():
    """Values that should be rejected as a field type."""
    return [
        # 1,
        "=c",
        # 0,
        None,
        "",
        "a".join("a" for _ in range(256)),
    ]
546
547
548
def gen_invalid_metric_types():
    """Values that should be rejected as a metric type."""
    return [
        1,
        "=c",
        0,
        None,
        "",
        "a".join("a" for _ in range(256)),
    ]
558
559
560
# TODO:
561
def gen_invalid_ints():
    """Values that should be rejected wherever an integer (e.g. top_k) is required."""
    return [
        # 1.0,
        None,
        "stringg",
        [1, 2, 3],
        (1, 2),
        {"a": 1},
        " ",
        "",
        "String",
        "12-s",
        "BB。A",
        " siede ",
        "(mn)",
        "pip+",
        "=c",
        "中文",
        "a".join("a" for _ in range(256)),
    ]
582
583
584
def gen_invalid_params():
    """Values that should be rejected as numeric index/search parameters."""
    return [
        9999999999,
        -1,
        # None,
        [1, 2, 3],
        (1, 2),
        {"a": 1},
        " ",
        "",
        "String",
        "12-s",
        "BB。A",
        " siede ",
        "(mn)",
        "pip+",
        "=c",
        "中文",
    ]
604
605
606
def gen_invalid_vectors():
    """Values that should be rejected wherever a vector is required."""
    return [
        "1*2",
        [],
        [1],
        [1, 2],
        [" "],
        ['a'],
        [None],
        None,
        (1, 2),
        {"a": 1},
        " ",
        "",
        "String",
        "12-s",
        "BB。A",
        " siede ",
        "(mn)",
        "pip+",
        "=c",
        "中文",
        "a".join("a" for _ in range(256)),
    ]
631
632
633
def gen_invaild_search_params():
    """Search-parameter dicts that should fail validation, per index type.

    NOTE: the misspelled name ("invaild") is kept for caller compatibility.
    """
    search_params = []
    for index_type in all_index_types:
        if index_type == "FLAT":
            continue
        # every non-FLAT type rejects an unknown parameter key
        search_params.append({"index_type": index_type,
                              "search_params": {"invalid_key": 100}})
        if index_type in delete_support():
            search_params.extend(
                {"index_type": index_type, "search_params": {"nprobe": bad}}
                for bad in gen_invalid_params())
        elif index_type == "HNSW":
            search_params.extend(
                {"index_type": index_type, "search_params": {"ef": bad}}
                for bad in gen_invalid_params())
        elif index_type == "NSG":
            search_params.extend(
                {"index_type": index_type, "search_params": {"search_length": bad}}
                for bad in gen_invalid_params())
            search_params.append({"index_type": index_type, "search_params": {"invalid_key": 100}})
        elif index_type == "ANNOY":
            # integers are valid search_k values, so only non-int params apply
            search_params.extend(
                {"index_type": index_type, "search_params": {"search_k": bad}}
                for bad in gen_invalid_params() if not isinstance(bad, int))
    return search_params
660
661
662
def gen_invalid_index():
    """Index definitions that should be rejected at build time.

    Combines invalid index_type strings with invalid numeric build parameters
    for IVF_FLAT / HNSW / NSG / ANNOY, plus unknown-key variants.
    """
    index_params = []
    # invalid index_type values (params themselves are valid)
    for index_type in gen_invalid_strs():
        index_param = {"index_type": index_type, "params": {"nlist": 1024}}
        index_params.append(index_param)
    # invalid nlist for IVF_FLAT
    for nlist in gen_invalid_params():
        index_param = {"index_type": "IVF_FLAT", "params": {"nlist": nlist}}
        index_params.append(index_param)
    # invalid M for HNSW
    for M in gen_invalid_params():
        index_param = {"index_type": "HNSW", "params": {"M": M, "efConstruction": 100}}
        index_params.append(index_param)
    # invalid efConstruction for HNSW
    for efConstruction in gen_invalid_params():
        index_param = {"index_type": "HNSW", "params": {"M": 16, "efConstruction": efConstruction}}
        index_params.append(index_param)
    # invalid search_length for NSG
    for search_length in gen_invalid_params():
        index_param = {"index_type": "NSG",
                       "params": {"search_length": search_length, "out_degree": 40, "candidate_pool_size": 50,
                                  "knng": 100}}
        index_params.append(index_param)
    # invalid out_degree for NSG
    for out_degree in gen_invalid_params():
        index_param = {"index_type": "NSG",
                       "params": {"search_length": 100, "out_degree": out_degree, "candidate_pool_size": 50,
                                  "knng": 100}}
        index_params.append(index_param)
    # invalid candidate_pool_size for NSG
    for candidate_pool_size in gen_invalid_params():
        index_param = {"index_type": "NSG", "params": {"search_length": 100, "out_degree": 40,
                                                       "candidate_pool_size": candidate_pool_size,
                                                       "knng": 100}}
        index_params.append(index_param)
    # unknown parameter keys
    index_params.append({"index_type": "IVF_FLAT", "params": {"invalid_key": 1024}})
    index_params.append({"index_type": "HNSW", "params": {"invalid_key": 16, "efConstruction": 100}})
    index_params.append({"index_type": "NSG",
                         "params": {"invalid_key": 100, "out_degree": 40, "candidate_pool_size": 300,
                                    "knng": 100}})
    # invalid n_trees for ANNOY
    for invalid_n_trees in gen_invalid_params():
        index_params.append({"index_type": "ANNOY", "params": {"n_trees": invalid_n_trees}})

    return index_params
700
701
702
def gen_index():
    """Cross-product of valid build parameters for every index type.

    NOTE(review): entries here use the key "index_param" while gen_simple_index
    and gen_binary_index use "params" — confirm which key the callers expect.
    """
    nlists = [1, 1024, 16384]
    pq_ms = [128, 64, 32, 16, 8, 4]
    Ms = [5, 24, 48]
    efConstructions = [100, 300, 500]
    search_lengths = [10, 100, 300]
    out_degrees = [5, 40, 300]
    candidate_pool_sizes = [50, 100, 300]
    knngs = [5, 100, 300]

    index_params = []
    for index_type in all_index_types:
        if index_type in ["FLAT", "BIN_FLAT", "BIN_IVF_FLAT"]:
            # flat indexes take a single fixed nlist
            index_params.append({"index_type": index_type, "index_param": {"nlist": 1024}})
        elif index_type in ["IVF_FLAT", "IVF_SQ8", "IVF_SQ8_HYBRID"]:
            ivf_params = [{"index_type": index_type, "index_param": {"nlist": nlist}} \
                          for nlist in nlists]
            index_params.extend(ivf_params)
        elif index_type == "IVF_PQ":
            IVFPQ_params = [{"index_type": index_type, "index_param": {"nlist": nlist, "m": m}} \
                            for nlist in nlists \
                            for m in pq_ms]
            index_params.extend(IVFPQ_params)
        elif index_type == "HNSW":
            hnsw_params = [{"index_type": index_type, "index_param": {"M": M, "efConstruction": efConstruction}} \
                           for M in Ms \
                           for efConstruction in efConstructions]
            index_params.extend(hnsw_params)
        elif index_type == "NSG":
            nsg_params = [{"index_type": index_type,
                           "index_param": {"search_length": search_length, "out_degree": out_degree,
                                           "candidate_pool_size": candidate_pool_size, "knng": knng}} \
                          for search_length in search_lengths \
                          for out_degree in out_degrees \
                          for candidate_pool_size in candidate_pool_sizes \
                          for knng in knngs]
            index_params.extend(nsg_params)

    return index_params
741
742
743
def gen_simple_index():
    """One default {index_type, metric L2, params} dict per non-binary index type.

    Iterates the parallel lists all_index_types / default_index_params with zip
    instead of the range(len(...)) index pattern.
    """
    index_params = []
    for index_type, params in zip(all_index_types, default_index_params):
        if index_type in binary_support():
            continue
        index_params.append({"index_type": index_type, "metric_type": "L2", "params": params})
    return index_params
752
753
754
def gen_binary_index():
    """One default {index_type, params} dict per binary index type.

    Iterates the parallel lists all_index_types / default_index_params with zip
    instead of the range(len(...)) index pattern.
    """
    index_params = []
    for index_type, params in zip(all_index_types, default_index_params):
        if index_type in binary_support():
            index_params.append({"index_type": index_type, "params": params})
    return index_params
762
763
764
def get_search_param(index_type):
    """Default search-time parameters (metric L2) for the given index type.

    Raises Exception for an unrecognized index type.
    """
    if index_type in ivf() or index_type in binary_support():
        extra = {"nprobe": 32}
    elif index_type == "HNSW":
        extra = {"ef": 64}
    elif index_type == "NSG":
        extra = {"search_length": 100}
    elif index_type == "ANNOY":
        extra = {"search_k": 100}
    else:
        logging.getLogger().error("Invalid index_type.")
        raise Exception("Invalid index_type.")
    search_params = {"metric_type": "L2"}
    search_params.update(extra)
    return search_params
778
779
780
def assert_equal_vector(v1, v2):
    """Assert v1 and v2 have equal length and element-wise differ by < epsilon.

    Replaces the bare `assert False` with a message and the index loop with zip.
    """
    assert len(v1) == len(v2), "vector lengths differ: %d != %d" % (len(v1), len(v2))
    for a, b in zip(v1, v2):
        assert abs(a - b) < epsilon
785
786
787
def restart_server(helm_release_name):
    """Delete the Milvus pod for `helm_release_name` and wait for it to restart.

    Returns True on success, False when the pod is not found, deletion fails,
    or the pod does not reach Running within the timeout.

    BUG FIX: the wait loop used `while time.time() - start_time > timeout`,
    which is false on the first iteration, so the loop never ran and the
    Running state was never actually polled; the comparison is now `<`.
    """
    res = True
    timeout = 120  # seconds to wait for the pod to come back up
    from kubernetes import client, config
    client.rest.logger.setLevel(logging.WARNING)

    namespace = "milvus"
    config.load_kube_config()
    v1 = client.CoreV1Api()
    pod_name = None
    pods = v1.list_namespaced_pod(namespace)
    # pick the release's server pod, skipping its mysql sidecar/deployment
    for i in pods.items:
        if i.metadata.name.find(helm_release_name) != -1 and i.metadata.name.find("mysql") == -1:
            pod_name = i.metadata.name
            break
    if pod_name is not None:
        try:
            v1.delete_namespaced_pod(pod_name, namespace)
        except Exception as e:
            logging.error(str(e))
            logging.error("Exception when calling CoreV1Api->delete_namespaced_pod")
            res = False
            return res
        time.sleep(5)
        # check if restart successfully
        pods = v1.list_namespaced_pod(namespace)
        for i in pods.items:
            pod_name_tmp = i.metadata.name
            if pod_name_tmp.find(helm_release_name) != -1:
                logging.debug(pod_name_tmp)
                start_time = time.time()
                # poll until the replacement pod is Running or we time out
                while time.time() - start_time < timeout:
                    status_res = v1.read_namespaced_pod_status(pod_name_tmp, namespace, pretty='true')
                    if status_res.status.phase == "Running":
                        break
                    time.sleep(1)
                if time.time() - start_time > timeout:
                    logging.error("Restart pod: %s timeout" % pod_name_tmp)
                    res = False
                    return res
    else:
        logging.error("Pod: %s not found" % helm_release_name)
        res = False
    return res
837