Passed
Push — master ( b86822...b6e78f )
by
unknown
02:07
created

utils.gen_binary_entities_rows()   A

Complexity

Conditions 4

Size

Total Lines 21
Code Lines 18

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 4
eloc 18
nop 2
dl 0
loc 21
rs 9.5
c 0
b 0
f 0
1
import os
2
import sys
3
import random
4
import pdb
5
import string
6
import struct
7
import logging
8
import threading
9
import time
10
import copy
11
import numpy as np
12
from sklearn import preprocessing
13
from milvus import Milvus, DataType
14
15
# --- connection / timing defaults -------------------------------------------
port = 19530                      # default Milvus gRPC port
epsilon = 0.000001                # tolerance for float comparisons in asserts
namespace = "milvus"              # k8s namespace the server runs in

default_flush_interval = 1
big_flush_interval = 1000
default_drop_interval = 3

# --- collection / search defaults -------------------------------------------
default_dim = 128
default_nb = 1200
default_top_k = 10
max_top_k = 16384
max_partition_num = 256
default_segment_row_limit = 1000
default_server_segment_row_limit = 1024 * 512
default_float_vec_field_name = "float_vector"
default_binary_vec_field_name = "binary_vector"
default_partition_name = "_default"
default_tag = "1970_01_01"

# TODO:
# TODO: disable RHNSW_SQ/PQ in 0.11.0
# Index types under test.  default_index_params is parallel to this list:
# entry i holds the build params for all_index_types[i].
all_index_types = [
    "FLAT",
    "IVF_FLAT",
    "IVF_SQ8",
    "IVF_SQ8_HYBRID",
    "IVF_PQ",
    "HNSW",
    # "NSG",
    "ANNOY",
    "RHNSW_PQ",
    "RHNSW_SQ",
    "BIN_FLAT",
    "BIN_IVF_FLAT"
]

default_index_params = [
    {"nlist": 128},                               # FLAT
    {"nlist": 128},                               # IVF_FLAT
    {"nlist": 128},                               # IVF_SQ8
    {"nlist": 128},                               # IVF_SQ8_HYBRID
    {"nlist": 128, "m": 16},                      # IVF_PQ
    {"M": 48, "efConstruction": 500},             # HNSW
    # {"search_length": 50, "out_degree": 40, "candidate_pool_size": 100, "knng": 50},
    {"n_trees": 50},                              # ANNOY
    {"M": 48, "efConstruction": 500, "PQM": 64},  # RHNSW_PQ
    {"M": 48, "efConstruction": 500},             # RHNSW_SQ
    {"nlist": 128},                               # BIN_FLAT
    {"nlist": 128}                                # BIN_IVF_FLAT
]
65
66
67
def index_cpu_not_support():
    """Index types that cannot be built on CPU-only deployments."""
    return ["IVF_SQ8_HYBRID"]


def binary_support():
    """Index types that work on binary vectors."""
    return ["BIN_FLAT", "BIN_IVF_FLAT"]


def delete_support():
    """Index types that support entity deletion."""
    return ["FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_SQ8_HYBRID", "IVF_PQ"]


def ivf():
    """IVF-family index types."""
    return ["FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_SQ8_HYBRID", "IVF_PQ"]


def skip_pq():
    """Quantization-based index types some tests skip."""
    return ["IVF_PQ", "RHNSW_PQ", "RHNSW_SQ"]


def binary_metrics():
    """Metric types valid for binary vectors."""
    return ["JACCARD", "HAMMING", "TANIMOTO", "SUBSTRUCTURE", "SUPERSTRUCTURE"]


def structure_metrics():
    """Binary metrics with asymmetric (sub/super-structure) semantics."""
    return ["SUBSTRUCTURE", "SUPERSTRUCTURE"]
93
94
95
def l2(x, y):
    """Euclidean (L2) distance between two float vectors."""
    return np.linalg.norm(np.array(x) - np.array(y))


def ip(x, y):
    """Inner-product similarity between two float vectors."""
    return np.inner(np.array(x), np.array(y))


def _as_bool(v):
    """Convert a 0/1 bit sequence to a boolean numpy array.

    Fix: ``np.bool`` was deprecated in numpy 1.20 and removed in 1.24;
    the builtin ``bool`` is the supported spelling.
    """
    return np.asarray(v, bool)


def jaccard(x, y):
    """Jaccard distance: 1 - |x AND y| / |x OR y|."""
    x = _as_bool(x)
    y = _as_bool(y)
    return 1 - np.double(np.bitwise_and(x, y).sum()) / np.double(np.bitwise_or(x, y).sum())


def hamming(x, y):
    """Hamming distance: number of differing bits."""
    x = _as_bool(x)
    y = _as_bool(y)
    return np.bitwise_xor(x, y).sum()


def tanimoto(x, y):
    """Tanimoto distance: -log2(|x AND y| / |x OR y|)."""
    x = _as_bool(x)
    y = _as_bool(y)
    return -np.log2(np.double(np.bitwise_and(x, y).sum()) / np.double(np.bitwise_or(x, y).sum()))


def substructure(x, y):
    """Substructure distance: 1 - |x AND y| / popcount(y)."""
    x = _as_bool(x)
    y = _as_bool(y)
    return 1 - np.double(np.bitwise_and(x, y).sum()) / np.count_nonzero(y)


def superstructure(x, y):
    """Superstructure distance: 1 - |x AND y| / popcount(x)."""
    x = _as_bool(x)
    y = _as_bool(y)
    return 1 - np.double(np.bitwise_and(x, y).sum()) / np.count_nonzero(x)
131
132
133
def get_milvus(host, port, uri=None, handler=None, **kwargs):
    """Build a Milvus client; ``uri`` takes precedence over host/port when given.

    ``try_connect`` may be passed through kwargs (defaults to True).
    """
    if handler is None:
        handler = "GRPC"
    try_connect = kwargs.get("try_connect", True)
    if uri is None:
        return Milvus(host=host, port=port, handler=handler, try_connect=try_connect)
    return Milvus(uri=uri, handler=handler, try_connect=try_connect)


def reset_build_index_threshold(connect):
    """Lower the server's build_index_threshold so indexing triggers early."""
    connect.set_config("engine", "build_index_threshold", 1024)


def disable_flush(connect):
    """Effectively disable auto-flush by setting a very large interval."""
    connect.set_config("storage", "auto_flush_interval", big_flush_interval)


def enable_flush(connect):
    """Restore auto_flush_interval to the default and verify the setting stuck."""
    connect.set_config("storage", "auto_flush_interval", default_flush_interval)
    config_value = connect.get_config("storage", "auto_flush_interval")
    assert config_value == str(default_flush_interval)
157
158
159
def gen_inaccuracy(num):
    """Scale ``num`` down to a small comparison delta (num / 255.0)."""
    return num / 255.0
161
162
163
def gen_vectors(num, dim, is_normal=True):
    """Return ``num`` random float vectors of ``dim`` dims, L2-normalized.

    NOTE(review): ``is_normal`` is accepted but never used — every result is
    normalized regardless.  Confirm whether conditional normalization was
    intended before changing this.
    """
    rows = [[random.random() for _ in range(dim)] for _ in range(num)]
    return preprocessing.normalize(rows, axis=1, norm='l2').tolist()
167
168
169
# def gen_vectors(num, dim, seed=np.random.RandomState(1234), is_normal=False):
170
#     xb = seed.rand(num, dim).astype("float32")
171
#     xb = preprocessing.normalize(xb, axis=1, norm='l2')
172
#     return xb.tolist()
173
174
175
def gen_binary_vectors(num, dim):
    """Return (raw_vectors, binary_vectors) for ``num`` random bit vectors.

    ``raw_vectors`` holds 0/1 int lists of length ``dim``; ``binary_vectors``
    holds the same bits packed into ``bytes`` via numpy packbits.
    """
    raw_vectors = []
    binary_vectors = []
    for _ in range(num):
        bits = [random.randint(0, 1) for _ in range(dim)]
        raw_vectors.append(bits)
        binary_vectors.append(bytes(np.packbits(bits, axis=-1).tolist()))
    return raw_vectors, binary_vectors
183
184
185 View Code Duplication
def gen_binary_sub_vectors(vectors, length):
    """Rebuild the first ``length`` bit vectors bit-by-bit, plus packed bytes.

    Each output raw vector starts as all zeros and copies over the 1-bits of
    the corresponding input vector (used for sub-structure query vectors).
    """
    raw_vectors = []
    binary_vectors = []
    dim = len(vectors[0])
    for idx in range(length):
        bits = [0] * dim
        for pos, bit in enumerate(vectors[idx]):
            if bit == 1:
                bits[pos] = 1
        raw_vectors.append(bits)
        binary_vectors.append(bytes(np.packbits(bits, axis=-1).tolist()))
    return raw_vectors, binary_vectors
198
199
200 View Code Duplication
def gen_binary_super_vectors(vectors, length):
    """Return ``length`` all-ones bit vectors (plus packed bytes) with the same
    dimension as ``vectors`` — every input is a substructure of these.

    Fix: dropped the dead ``cnt_1 = np.count_nonzero(vectors[i])`` computation,
    which had no effect on the result.
    """
    raw_vectors = []
    binary_vectors = []
    dim = len(vectors[0])
    for _ in range(length):
        ones = [1] * dim
        raw_vectors.append(ones)
        binary_vectors.append(bytes(np.packbits(ones, axis=-1).tolist()))
    return raw_vectors, binary_vectors
210
211
212
def gen_int_attr(row_num):
    """Random int attribute column, values in [0, 255]."""
    return [random.randint(0, 255) for _ in range(row_num)]


def gen_float_attr(row_num):
    """Random float attribute column, values in [0, 255]."""
    return [random.uniform(0, 255) for _ in range(row_num)]
218
219
220
def gen_unique_str(str_value=None):
    """Return a unique-ish name: '<base>_' + 8 random alphanumeric chars.

    The base is 'test' when ``str_value`` is None, otherwise ``str_value``.
    """
    suffix = "".join(random.choice(string.ascii_letters + string.digits) for _ in range(8))
    base = "test" if str_value is None else str_value
    return base + "_" + suffix
223
224
225
def gen_single_filter_fields():
    """One scalar field spec per filterable DataType (int32/64, float, double)."""
    filterable = [DataType.INT32, DataType.INT64, DataType.FLOAT, DataType.DOUBLE]
    return [{"name": dt.name, "type": dt} for dt in DataType if dt in filterable]


def gen_single_vector_fields():
    """One vector field spec per vector DataType, dim = default_dim."""
    return [
        {"name": dt.name, "type": dt, "params": {"dim": default_dim}}
        for dt in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]
    ]
239
240
241
def gen_default_fields(auto_id=True):
    """Default collection schema: int64 + float scalars and a float vector."""
    return {
        "fields": [
            {"name": "int64", "type": DataType.INT64},
            {"name": "float", "type": DataType.FLOAT},
            {"name": default_float_vec_field_name, "type": DataType.FLOAT_VECTOR, "params": {"dim": default_dim}},
        ],
        "segment_row_limit": default_segment_row_limit,
        "auto_id": auto_id,
    }


def gen_binary_default_fields(auto_id=True):
    """Default collection schema with a binary vector field instead of float."""
    return {
        "fields": [
            {"name": "int64", "type": DataType.INT64},
            {"name": "float", "type": DataType.FLOAT},
            {"name": default_binary_vec_field_name, "type": DataType.BINARY_VECTOR, "params": {"dim": default_dim}},
        ],
        "segment_row_limit": default_segment_row_limit,
        "auto_id": auto_id,
    }
265
266
267
def gen_entities(nb, is_normal=False):
    """Column-style entities (with explicit types) for the default float schema."""
    vectors = gen_vectors(nb, default_dim, is_normal)
    return [
        {"name": "int64", "type": DataType.INT64, "values": list(range(nb))},
        {"name": "float", "type": DataType.FLOAT, "values": [float(i) for i in range(nb)]},
        {"name": default_float_vec_field_name, "type": DataType.FLOAT_VECTOR, "values": vectors},
    ]


def gen_entities_new(nb, is_normal=False):
    """Column-style entities without explicit types (new insert format)."""
    vectors = gen_vectors(nb, default_dim, is_normal)
    return [
        {"name": "int64", "values": list(range(nb))},
        {"name": "float", "values": [float(i) for i in range(nb)]},
        {"name": default_float_vec_field_name, "values": vectors},
    ]


def gen_entities_rows(nb, is_normal=False, _id=True):
    """Row-style entities; when ``_id`` is False each row carries an explicit _id."""
    vectors = gen_vectors(nb, default_dim, is_normal)
    entities = []
    for i in range(nb):
        row = {}
        if not _id:
            # auto_id disabled: caller must supply primary keys
            row["_id"] = i
        row["int64"] = i
        row["float"] = float(i)
        row[default_float_vec_field_name] = vectors[i]
        entities.append(row)
    return entities
308
309
310
def gen_binary_entities(nb):
    """Column-style binary-vector entities (with explicit types).

    Returns (raw_vectors, entities); raw_vectors are the unpacked 0/1 lists.
    """
    raw_vectors, vectors = gen_binary_vectors(nb, default_dim)
    entities = [
        {"name": "int64", "type": DataType.INT64, "values": list(range(nb))},
        {"name": "float", "type": DataType.FLOAT, "values": [float(i) for i in range(nb)]},
        {"name": default_binary_vec_field_name, "type": DataType.BINARY_VECTOR, "values": vectors},
    ]
    return raw_vectors, entities


def gen_binary_entities_new(nb):
    """Column-style binary-vector entities without explicit types."""
    raw_vectors, vectors = gen_binary_vectors(nb, default_dim)
    entities = [
        {"name": "int64", "values": list(range(nb))},
        {"name": "float", "values": [float(i) for i in range(nb)]},
        {"name": default_binary_vec_field_name, "values": vectors},
    ]
    return raw_vectors, entities


def gen_binary_entities_rows(nb, _id=True):
    """Row-style binary-vector entities; explicit _id per row when ``_id`` is False."""
    raw_vectors, vectors = gen_binary_vectors(nb, default_dim)
    entities = []
    for i in range(nb):
        row = {}
        if not _id:
            # auto_id disabled: caller must supply primary keys
            row["_id"] = i
        row["int64"] = i
        row["float"] = float(i)
        row[default_binary_vec_field_name] = vectors[i]
        entities.append(row)
    return raw_vectors, entities
351
352
353 View Code Duplication
def gen_entities_by_fields(fields, nb, dim):
    """Fill each field spec in ``fields`` with ``nb`` generated values.

    Scalar fields get constant values, vector fields get random vectors.
    The field dicts are updated in place (a "values" key is added) and the
    same dicts are returned as the entity list.

    Fix: the original left ``field_value`` undefined for unrecognized field
    types, producing a confusing NameError; unknown types now raise ValueError.
    """
    entities = []
    for field in fields:
        if field["type"] in [DataType.INT32, DataType.INT64]:
            field_value = [1 for i in range(nb)]
        elif field["type"] in [DataType.FLOAT, DataType.DOUBLE]:
            field_value = [3.0 for i in range(nb)]
        elif field["type"] == DataType.BINARY_VECTOR:
            field_value = gen_binary_vectors(nb, dim)[1]
        elif field["type"] == DataType.FLOAT_VECTOR:
            field_value = gen_vectors(nb, dim)
        else:
            raise ValueError("Unsupported field type: %s" % field["type"])
        field.update({"values": field_value})
        entities.append(field)
    return entities
367
368
369
def assert_equal_entity(a, b):
    """Placeholder for entity equality checks; intentionally a no-op."""
    pass
371
372
373
def gen_query_vectors(field_name, entities, top_k, nq, search_params=None, rand_vector=False,
                      metric_type="L2", replace_vecs=None):
    """Build a vector search query dict plus the query vectors used.

    Query vectors are taken from the last entity column (or generated fresh
    when ``rand_vector`` is True), optionally overridden by ``replace_vecs``.
    Returns (query, query_vectors).

    Fix: replaced the shared mutable default ``search_params={"nprobe": 10}``
    with None (same effective default) so calls cannot alias one dict.
    """
    if search_params is None:
        search_params = {"nprobe": 10}
    if rand_vector is True:
        dimension = len(entities[-1]["values"][0])
        query_vectors = gen_vectors(nq, dimension)
    else:
        query_vectors = entities[-1]["values"][:nq]
    if replace_vecs:
        query_vectors = replace_vecs
    must_param = {"vector": {field_name: {"topk": top_k, "query": query_vectors, "params": search_params}}}
    must_param["vector"][field_name]["metric_type"] = metric_type
    query = {
        "bool": {
            "must": [must_param]
        }
    }
    return query, query_vectors
390
391
392
def update_query_expr(src_query, keep_old=True, expr=None):
    """Return a deep copy of ``src_query`` with ``expr`` merged into its "bool" clause.

    When ``keep_old`` is not True, the original "must" clause is removed.
    """
    result = copy.deepcopy(src_query)
    if expr is not None:
        result["bool"].update(expr)
    if keep_old is not True:
        result["bool"].pop("must")
    return result
399
400
401
def gen_default_vector_expr(default_query):
    """Extract the vector clause from a query built by gen_query_vectors."""
    return default_query["bool"]["must"][0]


def gen_default_term_expr(keyword="term", field="int64", values=None):
    """Build a term expression; defaults to the first default_nb // 2 ints."""
    if values is None:
        values = [i for i in range(default_nb // 2)]
    return {keyword: {field: {"values": values}}}
410
411
412
def update_term_expr(src_term, terms):
    """Return a deep copy of ``src_term`` with each dict in ``terms`` merged
    into its "term" clause."""
    merged = copy.deepcopy(src_term)
    for extra in terms:
        merged["term"].update(extra)
    return merged
417
418
419
def gen_default_range_expr(keyword="range", field="int64", ranges=None):
    """Build a range expression; defaults to GT 1 / LT default_nb // 2."""
    if ranges is None:
        ranges = {"GT": 1, "LT": default_nb // 2}
    return {keyword: {field: ranges}}


def update_range_expr(src_range, ranges):
    """Return a deep copy of ``src_range`` with each dict in ``ranges`` merged
    into its "range" clause.

    Fix: the loop variable no longer shadows the builtin ``range``.
    """
    merged = copy.deepcopy(src_range)
    for bound in ranges:
        merged["range"].update(bound)
    return merged
431
432
433
def gen_invalid_range():
    """Malformed "range" expressions for negative tests.

    Fix: renamed the local variable from ``range`` so it no longer shadows
    the builtin.
    """
    invalid_ranges = [
        {"range": 1},
        {"range": {}},
        {"range": []},
        {"range": {"range": {"int64": {"GT": 0, "LT": default_nb // 2}}}}
    ]
    return invalid_ranges


def gen_valid_ranges():
    """Well-formed range bounds for positive tests."""
    return [
        {"GT": 0, "LT": default_nb // 2},
        {"GT": default_nb // 2, "LT": default_nb * 2},
        {"GT": 0},
        {"LT": default_nb},
        {"GT": -1, "LT": default_top_k},
    ]


def gen_invalid_term():
    """Malformed "term" expressions for negative tests."""
    return [
        {"term": 1},
        {"term": []},
        {"term": {}},
        {"term": {"term": {"int64": {"values": [i for i in range(default_nb // 2)]}}}}
    ]
462
463
464 View Code Duplication
def add_field_default(default_fields, type=DataType.INT64, field_name=None):
    """Return a copy of ``default_fields`` with one extra scalar field appended.

    A random unique name is generated when ``field_name`` is None.
    NOTE: the parameter name ``type`` shadows the builtin but is part of the
    public signature and is kept for compatibility.
    """
    fields_copy = copy.deepcopy(default_fields)
    if field_name is None:
        field_name = gen_unique_str()
    fields_copy["fields"].append({"name": field_name, "type": type})
    return fields_copy
474
475
476 View Code Duplication
def add_field(entities, field_name=None):
    """Return a deep copy of ``entities`` with an extra INT64 column appended.

    The new column has one value per existing row; a random unique name is
    generated when ``field_name`` is None.
    """
    nb = len(entities[0]["values"])
    entities_copy = copy.deepcopy(entities)
    if field_name is None:
        field_name = gen_unique_str()
    entities_copy.append({
        "name": field_name,
        "type": DataType.INT64,
        "values": list(range(nb)),
    })
    return entities_copy
488
489
490 View Code Duplication
def add_vector_field(entities, is_normal=False):
    """Append a fresh random FLOAT_VECTOR column to ``entities`` in place.

    NOTE(review): this definition is shadowed by a later
    ``add_vector_field(nb, dimension=...)`` in this module, so callers get
    the later one — confirm which is intended.
    """
    nb = len(entities[0]["values"])
    vectors = gen_vectors(nb, default_dim, is_normal)
    entities.append({
        "name": gen_unique_str(),
        "type": DataType.FLOAT_VECTOR,
        "values": vectors,
    })
    return entities
500
501
502
# def update_fields_metric_type(fields, metric_type):
503
#     tmp_fields = copy.deepcopy(fields)
504
#     if metric_type in ["L2", "IP"]:
505
#         tmp_fields["fields"][-1]["type"] = DataType.FLOAT_VECTOR
506
#     else:
507
#         tmp_fields["fields"][-1]["type"] = DataType.BINARY_VECTOR
508
#     tmp_fields["fields"][-1]["params"]["metric_type"] = metric_type
509
#     return tmp_fields
510
511
512
def remove_field(entities):
    """Drop the first (scalar) column from ``entities`` in place and return it."""
    entities.pop(0)
    return entities


def remove_vector_field(entities):
    """Drop the last (vector) column from ``entities`` in place and return it."""
    entities.pop(-1)
    return entities
520
521
522
def update_field_name(entities, old_name, new_name):
    """Return a deep copy of ``entities`` with matching column names replaced."""
    updated = copy.deepcopy(entities)
    for column in updated:
        if column["name"] == old_name:
            column["name"] = new_name
    return updated


def update_field_type(entities, old_name, new_name):
    """Return a deep copy where columns named ``old_name`` get type ``new_name``."""
    updated = copy.deepcopy(entities)
    for column in updated:
        if column["name"] == old_name:
            column["type"] = new_name
    return updated
536
537
538
def update_field_value(entities, old_type, new_value):
    """Return a deep copy where every value of columns typed ``old_type``
    is replaced with ``new_value``."""
    updated = copy.deepcopy(entities)
    for column in updated:
        if column["type"] == old_type:
            column["values"] = [new_value] * len(column["values"])
    return updated
545
546
547
def update_field_name_row(entities, old_name, new_name):
    """Return a deep copy of row entities with key ``old_name`` renamed.

    Raises Exception when any row lacks ``old_name``.
    """
    updated = copy.deepcopy(entities)
    for row in updated:
        if old_name not in row:
            raise Exception("Field %s not in field" % old_name)
        row[new_name] = row[old_name]
        row.pop(old_name)
    return updated


def update_field_type_row(entities, old_name, new_name):
    """Return a deep copy where rows containing ``old_name`` get "type" set
    to ``new_name``; rows without the key are left untouched."""
    updated = copy.deepcopy(entities)
    for row in updated:
        if old_name in row:
            row["type"] = new_name
    return updated
564
565
566
def add_vector_field(nb, dimension=default_dim):
    """Generate a random FLOAT_VECTOR field spec and return its generated name.

    NOTE(review): this redefines the earlier ``add_vector_field(entities, ...)``
    in this module, and the ``vector_field`` dict built here is discarded —
    only the name is returned.  Looks suspicious; confirm the intent before
    relying on it.
    """
    field_name = gen_unique_str()
    vector_field = {
        "name": field_name,
        "type": DataType.FLOAT_VECTOR,
        "values": gen_vectors(nb, dimension),
    }
    return field_name
574
575
576
def gen_segment_row_limits():
    """Segment row limits exercised by the tests."""
    return [1024, 4096]
582
583
584
def gen_invalid_ips():
    """Malformed IP/host strings for negative connection tests."""
    return [
        "127.0.0",
        "12-s",
        " ",
        "12 s",
        "BB。A",
        " siede ",
        "(mn)",
        "中文",
        "a".join("a" for _ in range(256))
    ]


def gen_invalid_uris():
    """Malformed URIs for negative connection tests."""
    return [
        " ",
        "中文",
        # invalid ip portion
        "tcp:// :19530",
        "tcp://127.0.0:19530",
        "tcp://\n:19530",
    ]
631
632
633
def gen_invalid_strs():
    """Values that are not valid string/name arguments."""
    return [
        1,
        [1],
        None,
        "12-s",
        "12 s",
        "(mn)",
        "中文",
        "a".join("a" for i in range(256))
    ]


def gen_invalid_field_types():
    """Values that are not valid field types."""
    return [
        "=c",
        None,
        "",
        "a".join("a" for i in range(256))
    ]


def gen_invalid_metric_types():
    """Values that are not valid metric types."""
    return [
        1,
        "=c",
        0,
        None,
        "",
        "a".join("a" for i in range(256))
    ]


# TODO:
def gen_invalid_ints():
    """Values that are not valid integer arguments."""
    return [
        None,
        [1, 2, 3],
        " ",
        "",
        -1,
        "String",
        "=c",
        "中文",
        "a".join("a" for i in range(256))
    ]


def gen_invalid_params():
    """Values that are not valid numeric index/search parameters."""
    return [
        9999999999,
        -1,
        [1, 2, 3],
        " ",
        "",
        "String",
        "中文"
    ]


def gen_invalid_vectors():
    """Values that are not valid vector arguments."""
    return [
        "1*2",
        [],
        [1],
        [1, 2],
        [" "],
        ['a'],
        [None],
        None,
        (1, 2),
        {"a": 1},
        " ",
        "",
        "String",
        " siede ",
        "中文",
        "a".join("a" for i in range(256))
    ]
725
726
727 View Code Duplication
def gen_invaild_search_params():
    """Invalid search-param combinations per index type for negative tests.

    FLAT is skipped (it takes no tunable search params); every other type gets
    an unknown key plus a set of type-specific bad values.
    NOTE: the 'invaild' typo is part of the public name and kept for
    compatibility with existing callers.
    """
    invalid_search_key = 100
    search_params = []
    for index_type in all_index_types:
        if index_type == "FLAT":
            continue
        search_params.append({"index_type": index_type,
                              "search_params": {"invalid_key": invalid_search_key}})
        if index_type in delete_support():
            for nprobe in gen_invalid_params():
                search_params.append({"index_type": index_type,
                                      "search_params": {"nprobe": nprobe}})
        elif index_type in ["HNSW", "RHNSW_PQ", "RHNSW_SQ"]:
            for ef in gen_invalid_params():
                search_params.append({"index_type": index_type,
                                      "search_params": {"ef": ef}})
        elif index_type == "NSG":
            for search_length in gen_invalid_params():
                search_params.append({"index_type": index_type,
                                      "search_params": {"search_length": search_length}})
            search_params.append({"index_type": index_type,
                                  "search_params": {"invalid_key": 100}})
        elif index_type == "ANNOY":
            for search_k in gen_invalid_params():
                if isinstance(search_k, int):
                    continue
                search_params.append({"index_type": index_type,
                                      "search_params": {"search_k": search_k}})
    return search_params
754
755
756 View Code Duplication
def gen_invalid_index():
    """Invalid index build configurations for negative tests.

    Fix: the original reassigned ``index_param`` three times per iteration in
    the HNSW/RHNSW_PQ/RHNSW_SQ loops (and in the matching trailing appends)
    and only appended the last assignment, silently dropping the HNSW and
    RHNSW_PQ variants; every variant is now appended.
    """
    index_params = []
    for index_type in gen_invalid_strs():
        index_params.append({"index_type": index_type, "params": {"nlist": 1024}})
    for nlist in gen_invalid_params():
        index_params.append({"index_type": "IVF_FLAT", "params": {"nlist": nlist}})
    for M in gen_invalid_params():
        for hnsw_type in ("HNSW", "RHNSW_PQ", "RHNSW_SQ"):
            index_params.append({"index_type": hnsw_type,
                                 "params": {"M": M, "efConstruction": 100}})
    for efConstruction in gen_invalid_params():
        for hnsw_type in ("HNSW", "RHNSW_PQ", "RHNSW_SQ"):
            index_params.append({"index_type": hnsw_type,
                                 "params": {"M": 16, "efConstruction": efConstruction}})
    for search_length in gen_invalid_params():
        index_params.append({"index_type": "NSG",
                             "params": {"search_length": search_length, "out_degree": 40,
                                        "candidate_pool_size": 50, "knng": 100}})
    for out_degree in gen_invalid_params():
        index_params.append({"index_type": "NSG",
                             "params": {"search_length": 100, "out_degree": out_degree,
                                        "candidate_pool_size": 50, "knng": 100}})
    for candidate_pool_size in gen_invalid_params():
        index_params.append({"index_type": "NSG",
                             "params": {"search_length": 100, "out_degree": 40,
                                        "candidate_pool_size": candidate_pool_size,
                                        "knng": 100}})
    index_params.append({"index_type": "IVF_FLAT", "params": {"invalid_key": 1024}})
    for hnsw_type in ("HNSW", "RHNSW_PQ", "RHNSW_SQ"):
        index_params.append({"index_type": hnsw_type,
                             "params": {"invalid_key": 16, "efConstruction": 100}})
    index_params.append({"index_type": "NSG",
                         "params": {"invalid_key": 100, "out_degree": 40, "candidate_pool_size": 300,
                                    "knng": 100}})
    for invalid_n_trees in gen_invalid_params():
        index_params.append({"index_type": "ANNOY", "params": {"n_trees": invalid_n_trees}})

    return index_params
800
801
802 View Code Duplication
def gen_index():
    """Cartesian index-param grids for every index type under test."""
    nlists = [1, 1024, 16384]
    pq_ms = [128, 64, 32, 16, 8, 4]
    Ms = [5, 24, 48]
    efConstructions = [100, 300, 500]
    search_lengths = [10, 100, 300]
    out_degrees = [5, 40, 300]
    candidate_pool_sizes = [50, 100, 300]
    knngs = [5, 100, 300]

    index_params = []
    for index_type in all_index_types:
        if index_type in ["FLAT", "BIN_FLAT", "BIN_IVF_FLAT"]:
            index_params.append({"index_type": index_type, "index_param": {"nlist": 1024}})
        elif index_type in ["IVF_FLAT", "IVF_SQ8", "IVF_SQ8_HYBRID"]:
            index_params.extend({"index_type": index_type, "index_param": {"nlist": n}}
                                for n in nlists)
        elif index_type == "IVF_PQ":
            index_params.extend({"index_type": index_type, "index_param": {"nlist": n, "m": m}}
                                for n in nlists
                                for m in pq_ms)
        elif index_type in ["HNSW", "RHNSW_SQ", "RHNSW_PQ"]:
            index_params.extend({"index_type": index_type,
                                 "index_param": {"M": m, "efConstruction": ef}}
                                for m in Ms
                                for ef in efConstructions)
        elif index_type == "NSG":
            index_params.extend({"index_type": index_type,
                                 "index_param": {"search_length": sl, "out_degree": od,
                                                 "candidate_pool_size": cps, "knng": kg}}
                                for sl in search_lengths
                                for od in out_degrees
                                for cps in candidate_pool_sizes
                                for kg in knngs)

    return index_params
841
842
843
def gen_simple_index():
    """One default index config (metric L2) per non-binary index type."""
    index_params = []
    for index_type, params in zip(all_index_types, default_index_params):
        if index_type in binary_support():
            continue
        index_params.append({"index_type": index_type, "metric_type": "L2", "params": params})
    return index_params
852
853
854
def gen_binary_index():
    """One default index config per binary index type."""
    index_params = []
    for index_type, params in zip(all_index_types, default_index_params):
        if index_type in binary_support():
            index_params.append({"index_type": index_type, "params": params})
    return index_params
862
863
864 View Code Duplication
def get_search_param(index_type, metric_type="L2"):
    """Default search params for ``index_type``; raises on unknown types."""
    search_params = {"metric_type": metric_type}
    if index_type in ivf() or index_type in binary_support():
        search_params["nprobe"] = 64
    elif index_type in ["HNSW", "RHNSW_SQ", "RHNSW_PQ"]:
        search_params["ef"] = 64
    elif index_type == "NSG":
        search_params["search_length"] = 100
    elif index_type == "ANNOY":
        search_params["search_k"] = 1000
    else:
        logging.getLogger().error("Invalid index_type.")
        raise Exception("Invalid index_type.")
    return search_params
878
879
880
def assert_equal_vector(v1, v2):
    """Assert two float vectors match element-wise within ``epsilon``."""
    assert len(v1) == len(v2)
    for a, b in zip(v1, v2):
        assert abs(a - b) < epsilon
885
886
887 View Code Duplication
def _find_release_pod(v1, helm_release_name):
    """Return the name of the first non-mysql pod in ``namespace`` whose
    name contains *helm_release_name*, or None if no such pod exists."""
    for item in v1.list_namespaced_pod(namespace).items:
        name = item.metadata.name
        if name.find(helm_release_name) != -1 and name.find("mysql") == -1:
            return name
    return None


def _wait_pod_running(v1, pod_name, timeout):
    """Poll a pod's status until its phase is "Running" or *timeout*
    seconds elapse.  Returns True if the pod reached Running in time."""
    status_res = v1.read_namespaced_pod_status(pod_name, namespace, pretty='true')
    logging.error(status_res.status.phase)
    start_time = time.time()
    while time.time() - start_time <= timeout:
        status_res = v1.read_namespaced_pod_status(pod_name, namespace, pretty='true')
        if status_res.status.phase == "Running":
            logging.error("Already running")
            # give the service inside the pod a moment to come up
            time.sleep(10)
            return True
        time.sleep(1)
    logging.error("Restart pod: %s timeout" % pod_name)
    return False


def restart_server(helm_release_name):
    """Restart the Milvus server pod of a helm release and wait for its
    replacement to reach the Running phase.

    Deletes the first non-mysql pod whose name contains the release name,
    then waits for Kubernetes to reschedule a new matching pod.

    Args:
        helm_release_name (str): helm release whose server pod is restarted.

    Returns:
        bool: True if the replacement pod is Running within the timeout,
        False if the delete call fails or the wait times out.

    Raises:
        Exception: if no pod matching the release name is found.
    """
    timeout = 120
    from kubernetes import client, config
    client.rest.logger.setLevel(logging.WARNING)

    config.load_kube_config()
    v1 = client.CoreV1Api()

    pod_name = _find_release_pod(v1, helm_release_name)
    logging.getLogger().debug("Pod name: %s" % pod_name)
    if pod_name is None:
        # BUGFIX: the original formatted pod_name here, which is always
        # None on this path; name the release we were looking for instead.
        raise Exception("Pod of release: %s not found" % helm_release_name)

    try:
        v1.delete_namespaced_pod(pod_name, namespace)
    except Exception as e:
        logging.error(str(e))
        logging.error("Exception when calling CoreV1Api->delete_namespaced_pod")
        return False
    logging.error("Sleep 10s after pod deleted")
    time.sleep(10)

    # Check whether a replacement pod (same release, not mysql, not the
    # pod we just deleted) comes up and reaches Running.
    for item in v1.list_namespaced_pod(namespace).items:
        pod_name_tmp = item.metadata.name
        logging.error(pod_name_tmp)
        if pod_name_tmp == pod_name:
            continue  # the deleted pod may still be listed while terminating
        if pod_name_tmp.find(helm_release_name) == -1 or pod_name_tmp.find("mysql") != -1:
            continue  # unrelated pod
        if not _wait_pod_running(v1, pod_name_tmp, timeout):
            return False
        break
    return True
973
974
975
class TestThread(threading.Thread):
    """Thread that captures any exception raised by its target and
    re-raises it from join(), so failures in worker threads surface in
    the calling (test) thread instead of being swallowed."""

    def __init__(self, target, args=()):
        threading.Thread.__init__(self, target=target, args=args)
        # Initialize here so join() is safe even if the thread never ran
        # (the original only set this inside run(), risking AttributeError).
        self.exc = None

    def run(self):
        self.exc = None
        try:
            super(TestThread, self).run()
        except BaseException as e:
            # Stash the exception instead of letting the thread die silently.
            self.exc = e

    def join(self, timeout=None):
        # Accept timeout to keep the signature compatible with
        # threading.Thread.join (the original override dropped it).
        super(TestThread, self).join(timeout)
        if self.exc:
            raise self.exc
990