Passed
Push — master ( e4a70a...2b2ebc )
by
unknown
04:04 queued 02:05
created

utils.add_field_default()   A

Complexity

Conditions 2

Size

Total Lines 10
Code Lines 9

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 9
nop 3
dl 0
loc 10
rs 9.95
c 0
b 0
f 0
1
import os
2
import sys
3
import random
4
import pdb
5
import string
6
import struct
7
import logging
8
import time, datetime
9
import copy
10
import numpy as np
11
from sklearn import preprocessing
12
from milvus import Milvus, DataType
13
14
# --- Connection / test-tuning constants ---
port = 19530                    # default Milvus server port
epsilon = 0.000001              # tolerance used when comparing float vector elements
default_flush_interval = 1      # normal auto-flush interval (seconds)
big_flush_interval = 1000       # very large interval; effectively disables auto flush
dimension = 128                 # default dimension for generated vectors
segment_row_count = 5000        # default segment_row_count for generated schemas
default_float_vec_field_name = "float_vector"
default_binary_vec_field_name = "binary_vector"

# TODO:
# Index types under test.  Entries align BY POSITION with default_index_params
# below — keep the two lists in sync when editing either.
all_index_types = [
    "FLAT",
    "IVF_FLAT",
    "IVF_SQ8",
    "IVF_SQ8_HYBRID",
    "IVF_PQ",
    "HNSW",
    # "NSG",
    "ANNOY",
    "BIN_FLAT",
    "BIN_IVF_FLAT"
]

# Default build parameters, one entry per index type above (same order).
default_index_params = [
    {"nlist": 1024},
    {"nlist": 1024},
    {"nlist": 1024},
    {"nlist": 1024},
    {"nlist": 1024, "m": 16},
    {"M": 48, "efConstruction": 500},
    # {"search_length": 50, "out_degree": 40, "candidate_pool_size": 100, "knng": 50},
    {"n_trees": 4},
    {"nlist": 1024},
    {"nlist": 1024}
]
49
50
51
def index_cpu_not_support():
    """Index types that cannot be built on a CPU-only deployment."""
    return ["IVF_SQ8_HYBRID"]
53
54
55
def binary_support():
    """Index types that operate on binary vectors."""
    return ["BIN_FLAT", "BIN_IVF_FLAT"]
57
58
59
def delete_support():
    """Index types that support entity deletion."""
    return ["FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_SQ8_HYBRID", "IVF_PQ"]
61
62
63
def ivf():
    """The IVF family of index types (plus FLAT, which shares their search params)."""
    return ["FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_SQ8_HYBRID", "IVF_PQ"]
65
66
67
def l2(x, y):
    """Euclidean (L2) distance between two vectors."""
    diff = np.array(x) - np.array(y)
    return np.linalg.norm(diff)
69
70
71
def ip(x, y):
    """Inner-product similarity between two vectors."""
    a = np.array(x)
    b = np.array(y)
    return np.inner(a, b)
73
74
75
def jaccard(x, y):
    """Jaccard distance between two binary vectors: 1 - |x AND y| / |x OR y|.

    Fix: ``np.bool`` was deprecated and removed in NumPy 1.24; use the builtin
    ``bool`` as the dtype instead (identical behavior).
    """
    x = np.asarray(x, bool)
    y = np.asarray(y, bool)
    return 1 - np.double(np.bitwise_and(x, y).sum()) / np.double(np.bitwise_or(x, y).sum())
79
80
81
def hamming(x, y):
    """Hamming distance: number of positions where the binary vectors differ.

    Fix: ``np.bool`` was removed in NumPy 1.24; use the builtin ``bool`` dtype.
    """
    x = np.asarray(x, bool)
    y = np.asarray(y, bool)
    return np.bitwise_xor(x, y).sum()
85
86
87
def tanimoto(x, y):
    """Tanimoto distance: -log2(|x AND y| / |x OR y|) for binary vectors.

    Fix: ``np.bool`` was removed in NumPy 1.24; use the builtin ``bool`` dtype.
    """
    x = np.asarray(x, bool)
    y = np.asarray(y, bool)
    return -np.log2(np.double(np.bitwise_and(x, y).sum()) / np.double(np.bitwise_or(x, y).sum()))
91
92
93
def substructure(x, y):
    """Substructure distance: 1 - |x AND y| / |y| (0 when y's bits are all in x).

    Fix: ``np.bool`` was removed in NumPy 1.24; use the builtin ``bool`` dtype.
    """
    x = np.asarray(x, bool)
    y = np.asarray(y, bool)
    return 1 - np.double(np.bitwise_and(x, y).sum()) / np.count_nonzero(y)
97
98
99
def superstructure(x, y):
    """Superstructure distance: 1 - |x AND y| / |x| (0 when x's bits are all in y).

    Fix: ``np.bool`` was removed in NumPy 1.24; use the builtin ``bool`` dtype.
    """
    x = np.asarray(x, bool)
    y = np.asarray(y, bool)
    return 1 - np.double(np.bitwise_and(x, y).sum()) / np.count_nonzero(x)
103
104
105
def get_milvus(host, port, uri=None, handler=None, **kwargs):
    """Create a Milvus client, preferring ``uri`` over host/port when given.

    handler defaults to "GRPC"; kwargs may carry ``try_connect`` (default True).
    """
    chosen_handler = "GRPC" if handler is None else handler
    try_connect = kwargs.get("try_connect", True)
    if uri is None:
        return Milvus(host=host, port=port, handler=chosen_handler, try_connect=try_connect)
    return Milvus(uri=uri, handler=chosen_handler, try_connect=try_connect)
114
115
116
def disable_flush(connect):
    """Effectively turn off auto flush by setting a very large flush interval."""
    connect.set_config("storage", "auto_flush_interval", big_flush_interval)
118
119
120
def enable_flush(connect):
    """Restore auto_flush_interval to the default and verify the server accepted it."""
    # reset auto_flush_interval=1
    connect.set_config("storage", "auto_flush_interval", default_flush_interval)
    assert connect.get_config("storage", "auto_flush_interval") == str(default_flush_interval)
125
126
127
def gen_inaccuracy(num):
    """Scale ``num`` down by 255 to produce a small perturbation value."""
    scale = 255.0
    return num / scale
129
130
131
def gen_vectors(num, dim, is_normal=True):
    """Generate ``num`` random float vectors of dimension ``dim``, L2-normalized.

    NOTE(review): ``is_normal`` is currently ignored — vectors are always
    normalized regardless of the flag; confirm whether callers rely on that.
    """
    raw = [[random.random() for _ in range(dim)] for _ in range(num)]
    normalized = preprocessing.normalize(raw, axis=1, norm='l2')
    return normalized.tolist()
135
136
137
# def gen_vectors(num, dim, seed=np.random.RandomState(1234), is_normal=False):
138
#     xb = seed.rand(num, dim).astype("float32")
139
#     xb = preprocessing.normalize(xb, axis=1, norm='l2')
140
#     return xb.tolist()
141
142
143
def gen_binary_vectors(num, dim):
    """Generate ``num`` random binary vectors of ``dim`` bits.

    Returns (raw_vectors, binary_vectors): the raw 0/1 lists and the same
    vectors packed into bytes (8 bits per byte, MSB first via np.packbits).
    """
    raw_vectors = []
    binary_vectors = []
    for _ in range(num):
        bits = [random.randint(0, 1) for _ in range(dim)]
        raw_vectors.append(bits)
        binary_vectors.append(bytes(np.packbits(bits, axis=-1).tolist()))
    return raw_vectors, binary_vectors
151
152
153
def gen_binary_sub_vectors(vectors, length):
    """Copy the 1-bits of the first ``length`` vectors into fresh binary vectors.

    Each output vector keeps exactly the positions equal to 1 in the source
    (so it is a bitwise subset of it). Returns (raw_vectors, packed_bytes).
    """
    raw_vectors = []
    binary_vectors = []
    dim = len(vectors[0])
    for idx in range(length):
        source = vectors[idx]
        copied = [0] * dim
        for pos, bit in enumerate(source):
            if bit == 1:
                copied[pos] = 1
        raw_vectors.append(copied)
        binary_vectors.append(bytes(np.packbits(copied, axis=-1).tolist()))
    return raw_vectors, binary_vectors
166
167
168
def gen_binary_super_vectors(vectors, length):
    """Generate ``length`` all-ones binary vectors of the same dimension as
    ``vectors[0]`` (an all-ones vector is a superset of any binary vector).

    Returns (raw_vectors, packed_bytes).
    Fix: removed the unused ``cnt_1 = np.count_nonzero(...)`` dead assignment.
    """
    raw_vectors = []
    binary_vectors = []
    dim = len(vectors[0])
    for _ in range(length):
        raw_vector = [1] * dim
        raw_vectors.append(raw_vector)
        binary_vectors.append(bytes(np.packbits(raw_vector, axis=-1).tolist()))
    return raw_vectors, binary_vectors
178
179
180
def gen_int_attr(row_num):
    """Generate ``row_num`` random integer attribute values in [0, 255]."""
    pick = random.randint
    return [pick(0, 255) for _ in range(row_num)]
182
183
184
def gen_float_attr(row_num):
    """Generate ``row_num`` random float attribute values in [0, 255]."""
    pick = random.uniform
    return [pick(0, 255) for _ in range(row_num)]
186
187
188
def gen_unique_str(str_value=None):
    """Return a unique-ish name: ``<base>_<8 random alphanumerics>``.

    base is "test" unless ``str_value`` is provided.
    """
    alphabet = string.ascii_letters + string.digits
    suffix = "".join(random.choice(alphabet) for _ in range(8))
    base = "test" if str_value is None else str_value
    return base + "_" + suffix
191
192
193
def gen_single_filter_fields():
    """One field spec per scalar (filterable) data type supported here."""
    scalar_types = (DataType.INT32, DataType.INT64, DataType.FLOAT, DataType.DOUBLE)
    return [{"field": dt.name, "type": dt} for dt in DataType if dt in scalar_types]
199
200
201
def gen_single_vector_fields():
    """One field spec per vector type (float and binary), using the module default dim."""
    return [
        {"field": dt.name, "type": dt, "params": {"dim": dimension}}
        for dt in (DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR)
    ]
207
208
209
def gen_default_fields():
    """Default collection schema: int64 + float scalars plus one float-vector field."""
    return {
        "fields": [
            {"field": "int64", "type": DataType.INT64},
            {"field": "float", "type": DataType.FLOAT},
            {"field": default_float_vec_field_name, "type": DataType.FLOAT_VECTOR,
             "params": {"dim": dimension}},
        ],
        "segment_row_count": segment_row_count,
    }
219
220
221
def gen_entities(nb, is_normal=False):
    """Generate ``nb`` rows of default entities: sequential ints/floats plus vectors."""
    vectors = gen_vectors(nb, dimension, is_normal)
    return [
        {"field": "int64", "type": DataType.INT64, "values": [i for i in range(nb)]},
        {"field": "float", "type": DataType.FLOAT, "values": [float(i) for i in range(nb)]},
        {"field": default_float_vec_field_name, "type": DataType.FLOAT_VECTOR,
         "values": vectors},
    ]
229
230
231
def gen_binary_entities(nb):
    """Generate ``nb`` rows with a binary-vector field.

    Returns (raw_vectors, entities) so tests can compare against the unpacked bits.
    """
    raw_vectors, vectors = gen_binary_vectors(nb, dimension)
    entities = [
        {"field": "int64", "type": DataType.INT64, "values": [i for i in range(nb)]},
        {"field": "float", "type": DataType.FLOAT, "values": [float(i) for i in range(nb)]},
        {"field": default_binary_vec_field_name, "type": DataType.BINARY_VECTOR,
         "values": vectors},
    ]
    return raw_vectors, entities
239
240
241
def gen_entities_by_fields(fields, nb, dimension):
    """Fill each field spec with ``nb`` generated values (mutates and returns the dicts).

    Scalars get constant values; vector fields get freshly generated vectors.
    Fix: the original left ``field_value`` undefined for unrecognized types,
    raising a confusing NameError — now raises ValueError explicitly.
    """
    entities = []
    for field in fields:
        data_type = field["type"]
        if data_type in [DataType.INT32, DataType.INT64]:
            field_value = [1 for _ in range(nb)]
        elif data_type in [DataType.FLOAT, DataType.DOUBLE]:
            field_value = [3.0 for _ in range(nb)]
        elif data_type == DataType.BINARY_VECTOR:
            field_value = gen_binary_vectors(nb, dimension)[1]
        elif data_type == DataType.FLOAT_VECTOR:
            field_value = gen_vectors(nb, dimension)
        else:
            raise ValueError("Unsupported field type: %s" % data_type)
        field.update({"values": field_value})
        entities.append(field)
    return entities
255
256
257
def assert_equal_entity(a, b):
    # TODO: placeholder — entity equality check is not implemented yet; callers
    # currently get no verification from this helper.
    pass
259
260
261
def gen_query_vectors(field_name, entities, top_k, nq, search_params=None, rand_vector=False,
                      metric_type=None):
    """Build a vector search query plus the query vectors it targets.

    Takes the first ``nq`` vectors from the last entity field, or random ones
    when ``rand_vector`` is True. ``search_params`` defaults to {"nprobe": 10}.
    Fix: the default was a mutable dict literal shared across calls (and
    embedded in every returned query) — replaced with the None-sentinel idiom.

    Returns (query, query_vectors).
    """
    if search_params is None:
        search_params = {"nprobe": 10}
    if rand_vector is True:
        dimension = len(entities[-1]["values"][0])
        query_vectors = gen_vectors(nq, dimension)
    else:
        query_vectors = entities[-1]["values"][:nq]
    must_param = {"vector": {field_name: {"topk": top_k, "query": query_vectors,
                                          "params": search_params}}}
    if metric_type is not None:
        must_param["vector"][field_name]["metric_type"] = metric_type
    query = {
        "bool": {
            "must": [must_param]
        }
    }
    return query, query_vectors
277
278
279
def update_query_expr(src_query, keep_old=True, expr=None):
    """Return a deep copy of ``src_query`` with ``expr`` merged into its bool
    clause; drop the original ``must`` clause unless ``keep_old`` is True."""
    updated = copy.deepcopy(src_query)
    bool_clause = updated["bool"]
    if expr is not None:
        bool_clause.update(expr)
    if keep_old is not True:
        bool_clause.pop("must")
    return updated
286
287
288
def gen_default_vector_expr(default_query):
    """Extract the vector clause (first ``must`` entry) from a default query."""
    must_clauses = default_query["bool"]["must"]
    return must_clauses[0]
290
291
292
def gen_default_term_expr(keyword="term", values=None, nb=6000):
    """Build a term expression on the "int64" field.

    Fix: the original referenced an undefined module-level ``nb`` (NameError)
    and used ``range(nb / 2)`` — a float argument, a TypeError on Python 3.
    ``nb`` is now a backward-compatible keyword parameter and the division is
    integer division.
    """
    if values is None:
        values = list(range(nb // 2))
    expr = {keyword: {"int64": {"values": values}}}
    return expr
297
298
299
def gen_default_range_expr(ranges=None, nb=6000):
    """Build a range expression on the "int64" field.

    Fix: the original referenced an undefined module-level ``nb`` (NameError
    whenever ``ranges`` was omitted). ``nb`` is now a backward-compatible
    keyword parameter; the default upper bound uses integer division.
    """
    if ranges is None:
        ranges = {"GT": 1, "LT": nb // 2}
    expr = {"range": {"int64": {"ranges": ranges}}}
    return expr
304
305
306
def add_field_default(default_fields, type=DataType.INT64, field_name=None):
    """Return a deep copy of ``default_fields`` with one extra scalar field appended.

    ``type`` shadows the builtin but is kept for caller compatibility.
    A random name is generated when ``field_name`` is omitted.
    """
    updated = copy.deepcopy(default_fields)
    name = field_name if field_name is not None else gen_unique_str()
    updated["fields"].append({"field": name, "type": type})
    return updated
316
317
318
def add_field(entities, field_name=None):
    """Return a deep copy of ``entities`` with an extra INT64 field appended.

    The new field holds sequential ids matching the existing row count.
    """
    row_count = len(entities[0]["values"])
    extended = copy.deepcopy(entities)
    name = field_name if field_name is not None else gen_unique_str()
    extended.append({
        "field": name,
        "type": DataType.INT64,
        "values": list(range(row_count)),
    })
    return extended
330
331
332
def add_vector_field(entities, is_normal=False):
    """Append one extra FLOAT_VECTOR field (mutates and returns ``entities``).

    NOTE(review): a later definition in this file reuses the name
    ``add_vector_field`` with a different signature, shadowing this one at
    import time — confirm which version callers actually get.
    """
    row_count = len(entities[0]["values"])
    entities.append({
        "field": gen_unique_str(),
        "type": DataType.FLOAT_VECTOR,
        "values": gen_vectors(row_count, dimension, is_normal),
    })
    return entities
342
343
344
# def update_fields_metric_type(fields, metric_type):
345
#     tmp_fields = copy.deepcopy(fields)
346
#     if metric_type in ["L2", "IP"]:
347
#         tmp_fields["fields"][-1]["type"] = DataType.FLOAT_VECTOR
348
#     else:
349
#         tmp_fields["fields"][-1]["type"] = DataType.BINARY_VECTOR
350
#     tmp_fields["fields"][-1]["params"]["metric_type"] = metric_type
351
#     return tmp_fields
352
353
354
def remove_field(entities):
    """Drop the first field (mutates and returns ``entities``)."""
    entities.pop(0)
    return entities
357
358
359
def remove_vector_field(entities):
    """Drop the last field — by convention the vector field (mutates and returns)."""
    entities.pop()
    return entities
362
363
364
def update_field_name(entities, old_name, new_name):
    """Rename every field whose name equals ``old_name`` (in place)."""
    for entity in entities:
        if entity["field"] == old_name:
            entity["field"] = new_name
    return entities
369
370
371
def update_field_type(entities, old_name, new_name):
    """Set the type of every field named ``old_name`` to ``new_name`` (in place)."""
    for entity in entities:
        if entity["field"] == old_name:
            entity["type"] = new_name
    return entities
376
377
378
def update_field_value(entities, old_type, new_value):
    """Replace every value of fields whose type equals ``old_type`` with
    ``new_value`` (in place).

    Fix: the original iterated ``for i in item["values"]`` and then indexed
    with the values themselves (``item["values"][i] = ...``) — that only
    worked when the values happened to be in-range integers and crashed or
    corrupted data otherwise. Now the whole list is overwritten in place.
    """
    for item in entities:
        if item["type"] == old_type:
            # slice-assign so callers holding a reference to the list see the change
            item["values"][:] = [new_value] * len(item["values"])
    return entities
384
385
386
def add_vector_field(nb, dimension=dimension):
    # NOTE(review): this redefines add_vector_field(entities, is_normal) declared
    # earlier in this file, shadowing it at import time — presumably one of the
    # two should be renamed; confirm which signature callers expect.
    field_name = gen_unique_str()
    # NOTE(review): this dict is built but never returned or stored anywhere —
    # the function returns only the generated name below, so `field` is dead
    # code; it looks like `return field` may have been intended. Confirm.
    field = {
        "field": field_name,
        "type": DataType.FLOAT_VECTOR,
        "values": gen_vectors(nb, dimension)
    }
    return field_name
394
395
396
def gen_segment_row_counts():
    """Segment row-count values to exercise: tiny (1, 2) through large (1024, 4096)."""
    return [1, 2, 1024, 4096]
404
405
406
def gen_invalid_ips():
    """Host strings that should be rejected as invalid IP addresses."""
    malformed = [
        "127.0.0",       # truncated dotted quad
        "12-s",
        " ",
        "12 s",
        "BB。A",
        " siede ",
        "(mn)",
        "中文",
        "a" * 511,       # over-long host (same value as the original join expression)
    ]
    return malformed
424
425
426
def gen_invalid_uris():
    """URI strings that should be rejected by the client (bad or missing host part)."""
    bad_uris = [
        " ",
        "中文",
        # invalid ip
        "tcp:// :19530",
        "tcp://127.0.0:19530",
        "tcp://\n:19530",
    ]
    return bad_uris
453
454
455
def gen_invalid_strs():
    """Values that should be rejected wherever a well-formed name string is expected."""
    bad_values = [
        1,              # wrong type: int
        [1],            # wrong type: list
        None,
        "12-s",
        " ",
        "12 s",
        "BB。A",
        "c|c",
        " siede ",
        "(mn)",
        "pip+",
        "=c",
        "中文",
        "a" * 511,      # over-long (same value as the original join expression)
    ]
    return bad_values
475
476
477
def gen_invalid_field_types():
    """Values that should be rejected as a field type."""
    bad_types = [
        "=c",
        None,
        "",
        "a" * 511,      # over-long (same value as the original join expression)
    ]
    return bad_types
487
488
489
def gen_invalid_metric_types():
    """Values that should be rejected as a metric type."""
    bad_metrics = [
        1,
        "=c",
        0,
        None,
        "",
        "a" * 511,      # over-long (same value as the original join expression)
    ]
    return bad_metrics
499
500
501
# TODO:
502
def gen_invalid_ints():
    """Values that should be rejected wherever an integer (e.g. top_k) is expected."""
    bad_ints = [
        None,
        "stringg",
        [1, 2, 3],
        (1, 2),
        {"a": 1},
        " ",
        "",
        "String",
        "12-s",
        "BB。A",
        " siede ",
        "(mn)",
        "pip+",
        "=c",
        "中文",
        "a" * 511,      # over-long (same value as the original join expression)
    ]
    return bad_ints
523
524
525
def gen_invalid_params():
    """Values that should be rejected as numeric search/index parameters."""
    bad_params = [
        9999999999,     # out of range
        -1,             # negative
        [1, 2, 3],
        (1, 2),
        {"a": 1},
        " ",
        "",
        "String",
        "12-s",
        "BB。A",
        " siede ",
        "(mn)",
        "pip+",
        "=c",
        "中文",
    ]
    return bad_params
545
546
547
def gen_invalid_vectors():
    """Values that should be rejected wherever a vector is expected."""
    bad_vectors = [
        "1*2",
        [],
        [1],
        [1, 2],
        [" "],
        ['a'],
        [None],
        None,
        (1, 2),
        {"a": 1},
        " ",
        "",
        "String",
        "12-s",
        "BB。A",
        " siede ",
        "(mn)",
        "pip+",
        "=c",
        "中文",
        "a" * 511,      # over-long (same value as the original join expression)
    ]
    return bad_vectors
572
573
574
def gen_invaild_search_params():
    """Invalid search-parameter combinations per index type.

    (The "invaild" typo is kept — callers reference this name.)
    FLAT is skipped because it takes no tunable search params.
    """
    bad_key_value = 100
    invalid_params = []
    for index_type in all_index_types:
        if index_type == "FLAT":
            continue
        invalid_params.append(
            {"index_type": index_type, "search_params": {"invalid_key": bad_key_value}})
        if index_type in delete_support():
            invalid_params.extend(
                {"index_type": index_type, "search_params": {"nprobe": bad}}
                for bad in gen_invalid_params())
        elif index_type == "HNSW":
            invalid_params.extend(
                {"index_type": index_type, "search_params": {"ef": bad}}
                for bad in gen_invalid_params())
        elif index_type == "NSG":
            # inactive while "NSG" stays commented out of all_index_types
            invalid_params.extend(
                {"index_type": index_type, "search_params": {"search_length": bad}}
                for bad in gen_invalid_params())
            invalid_params.append(
                {"index_type": index_type, "search_params": {"invalid_key": 100}})
        elif index_type == "ANNOY":
            # integers are valid search_k values, so only keep the non-int junk
            invalid_params.extend(
                {"index_type": index_type, "search_params": {"search_k": bad}}
                for bad in gen_invalid_params()
                if not isinstance(bad, int))
    return invalid_params
601
602
603
def gen_invalid_index():
    """Index configs that should be rejected: bad type names, bad numeric
    parameters, and unknown parameter keys."""
    bad_configs = []
    bad_configs.extend({"index_type": bad_type, "params": {"nlist": 1024}}
                       for bad_type in gen_invalid_strs())
    bad_configs.extend({"index_type": "IVF_FLAT", "params": {"nlist": bad}}
                       for bad in gen_invalid_params())
    bad_configs.extend({"index_type": "HNSW", "params": {"M": bad, "efConstruction": 100}}
                       for bad in gen_invalid_params())
    bad_configs.extend({"index_type": "HNSW", "params": {"M": 16, "efConstruction": bad}}
                       for bad in gen_invalid_params())
    bad_configs.extend({"index_type": "NSG",
                        "params": {"search_length": bad, "out_degree": 40,
                                   "candidate_pool_size": 50, "knng": 100}}
                       for bad in gen_invalid_params())
    bad_configs.extend({"index_type": "NSG",
                        "params": {"search_length": 100, "out_degree": bad,
                                   "candidate_pool_size": 50, "knng": 100}}
                       for bad in gen_invalid_params())
    bad_configs.extend({"index_type": "NSG",
                        "params": {"search_length": 100, "out_degree": 40,
                                   "candidate_pool_size": bad, "knng": 100}}
                       for bad in gen_invalid_params())
    # unknown parameter keys for otherwise valid index types
    bad_configs.append({"index_type": "IVF_FLAT", "params": {"invalid_key": 1024}})
    bad_configs.append({"index_type": "HNSW", "params": {"invalid_key": 16, "efConstruction": 100}})
    bad_configs.append({"index_type": "NSG",
                        "params": {"invalid_key": 100, "out_degree": 40,
                                   "candidate_pool_size": 300, "knng": 100}})
    bad_configs.extend({"index_type": "ANNOY", "params": {"n_trees": bad}}
                       for bad in gen_invalid_params())

    return bad_configs
641
642
643
def gen_index():
    """Enumerate build configurations (cartesian product of tuning values) for
    every index type in all_index_types."""
    nlist_values = [1, 1024, 16384]
    pq_m_values = [128, 64, 32, 16, 8, 4]
    hnsw_m_values = [5, 24, 48]
    ef_construction_values = [100, 300, 500]
    search_length_values = [10, 100, 300]
    out_degree_values = [5, 40, 300]
    pool_size_values = [50, 100, 300]
    knng_values = [5, 100, 300]

    configs = []
    for index_type in all_index_types:
        if index_type in ["FLAT", "BIN_FLAT", "BIN_IVF_FLAT"]:
            configs.append({"index_type": index_type, "index_param": {"nlist": 1024}})
        elif index_type in ["IVF_FLAT", "IVF_SQ8", "IVF_SQ8_HYBRID"]:
            for nlist in nlist_values:
                configs.append({"index_type": index_type, "index_param": {"nlist": nlist}})
        elif index_type == "IVF_PQ":
            for nlist in nlist_values:
                for m in pq_m_values:
                    configs.append({"index_type": index_type,
                                    "index_param": {"nlist": nlist, "m": m}})
        elif index_type == "HNSW":
            for m in hnsw_m_values:
                for ef in ef_construction_values:
                    configs.append({"index_type": index_type,
                                    "index_param": {"M": m, "efConstruction": ef}})
        elif index_type == "NSG":
            # inactive while "NSG" stays commented out of all_index_types
            for sl in search_length_values:
                for od in out_degree_values:
                    for ps in pool_size_values:
                        for knng in knng_values:
                            configs.append({"index_type": index_type,
                                            "index_param": {"search_length": sl,
                                                            "out_degree": od,
                                                            "candidate_pool_size": ps,
                                                            "knng": knng}})

    return configs
682
683
684
def gen_simple_index():
    """One default L2 index config per non-binary index type.

    Relies on all_index_types and default_index_params being aligned by position.
    """
    configs = []
    for index_type, default_param in zip(all_index_types, default_index_params):
        if index_type in binary_support():
            continue
        configs.append({"index_type": index_type, "metric_type": "L2",
                        "params": default_param})
    return configs
693
694
695
def gen_binary_index():
    """One default index config per binary index type.

    Relies on all_index_types and default_index_params being aligned by position.
    """
    configs = []
    for index_type, default_param in zip(all_index_types, default_index_params):
        if index_type in binary_support():
            configs.append({"index_type": index_type, "params": default_param})
    return configs
703
704
705
def get_search_param(index_type):
    """Return default search parameters (metric L2) for ``index_type``.

    Raises Exception for unknown index types (after logging the error).
    """
    params = {"metric_type": "L2"}
    if index_type in ivf() or index_type in binary_support():
        params["nprobe"] = 32
    elif index_type == "HNSW":
        params["ef"] = 64
    elif index_type == "NSG":
        params["search_length"] = 100
    elif index_type == "ANNOY":
        params["search_k"] = 100
    else:
        logging.getLogger().error("Invalid index_type.")
        raise Exception("Invalid index_type.")
    return params
719
720
721
def assert_equal_vector(v1, v2):
    """Assert two float vectors are equal element-wise within ``epsilon``."""
    assert len(v1) == len(v2)
    for a, b in zip(v1, v2):
        assert abs(a - b) < epsilon
726
727
728
def restart_server(helm_release_name):
    """Delete the Milvus server pod for ``helm_release_name`` and wait (up to
    120s) for its replacement to reach the Running phase.

    Returns True on success, False on delete failure / timeout / pod not found.

    Bug fix: the status-polling loop used ``while elapsed > timeout`` — false
    on the first iteration, so the loop never ran and the restart was never
    actually awaited. The condition is now ``< timeout``.
    """
    res = True
    timeout = 120
    from kubernetes import client, config
    client.rest.logger.setLevel(logging.WARNING)

    namespace = "milvus"
    config.load_kube_config()
    v1 = client.CoreV1Api()
    pod_name = None
    pods = v1.list_namespaced_pod(namespace)
    # find the server pod for this release (exclude the mysql pod)
    for i in pods.items:
        if i.metadata.name.find(helm_release_name) != -1 and i.metadata.name.find("mysql") == -1:
            pod_name = i.metadata.name
            break
    if pod_name is not None:
        try:
            v1.delete_namespaced_pod(pod_name, namespace)
        except Exception as e:
            logging.error(str(e))
            logging.error("Exception when calling CoreV1Api->delete_namespaced_pod")
            res = False
            return res
        time.sleep(5)
        # check if restart successfully
        pods = v1.list_namespaced_pod(namespace)
        for i in pods.items:
            pod_name_tmp = i.metadata.name
            if pod_name_tmp.find(helm_release_name) != -1:
                logging.debug(pod_name_tmp)
                start_time = time.time()
                # poll until Running or timeout (fixed: was `>`, which skipped the loop)
                while time.time() - start_time < timeout:
                    status_res = v1.read_namespaced_pod_status(pod_name_tmp, namespace,
                                                               pretty='true')
                    if status_res.status.phase == "Running":
                        break
                    time.sleep(1)
                if time.time() - start_time > timeout:
                    logging.error("Restart pod: %s timeout" % pod_name_tmp)
                    res = False
                    return res
    else:
        logging.error("Pod: %s not found" % helm_release_name)
        res = False
    return res
778