ServiceHandler._do_query()   D

Complexity

Conditions: 12

Size

Total Lines: 62
Code Lines: 54

Duplication

Duplicated Lines: 0
Duplication Ratio: 0 %

Importance

Changes: 0

Metric  Value
cc      12     (cyclomatic complexity)
eloc    54     (effective lines of code)
nop     9      (number of parameters)
dl      0      (duplicated lines)
loc     62     (total lines)
rs      4.8
c       0
b       0
f       0

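The cc value counts linearly independent paths through the method: it starts at 1, and every branch point (an if, elif, loop, or boolean operator) adds one. A minimal illustration (hypothetical code, not from this service):

def classify(score):      # cc starts at 1
    if score >= 90:       # +1
        return 'high'
    elif score >= 50:     # +1
        return 'medium'
    return 'low'          # total cc = 3
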
How to fix

Long Method

Small methods make your code easier to understand, especially when combined with a good name. And when a method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, that is usually a sign that the commented part should be extracted into a new method, with the comment as a starting point for the new method's name.

Commonly applied refactorings include:

- Extract Method
- Decompose Conditional
- Replace Temp with Query

A minimal Extract Method sketch follows.
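
The sketch below (hypothetical code, not taken from this service) shows the comment-to-method move described above: the comment labeling a step inside the method becomes the name of an extracted helper.

# Before: a comment labels one step inside a longer method.
def nearest_ids(distances, k):
    # keep the k entries with the smallest distances
    best = sorted(enumerate(distances), key=lambda x: x[1])[:k]
    return [i for i, _ in best]

# After: the step is extracted, and the comment becomes the method name.
def smallest_k(distances, k):
    return sorted(enumerate(distances), key=lambda x: x[1])[:k]

def nearest_ids(distances, k):
    return [i for i, _ in smallest_k(distances, k)]
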

Complexity

Complex methods like mishards.service_handler.ServiceHandler._do_query() often do a lot of different things. To break such a method down, we need to identify a cohesive component within its class. A common way to find such a component is to look for fields and methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
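
As a minimal sketch of Extract Class here, the scatter/gather part of _do_query() could move into its own class. The class name SearchFanout is invented; routing(), query_conn(), and search_in_segment() are the calls _do_query() already makes, so only the grouping is new:

class SearchFanout:
    """Hypothetical extraction: fans a search out across routed nodes."""

    def __init__(self, router):
        self.router = router

    def search(self, collection_id, vectors, topk, params, partition_tags=None):
        futures = []
        # One async search per routed address, then gather the raw results.
        for addr, (file_ids, _) in self.router.routing(
                collection_id, partition_tags=partition_tags).items():
            conn = self.router.query_conn(addr)
            futures.append(conn.search_in_segment(collection_name=collection_id,
                                                  file_ids=file_ids,
                                                  query_records=vectors,
                                                  top_k=topk,
                                                  params=params,
                                                  _async=True))
        return [f.result(raw=True) for f in futures]

With this in place, ServiceHandler would only translate gRPC requests and delegate the fan-out and merge to the new class.
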

Many Parameters

Methods with many parameters are not only hard to understand; their parameter lists also tend to become inconsistent as you need more, or different, data.

There are several approaches to avoiding long parameter lists:

- Introduce Parameter Object: group parameters that travel together into one object.
- Preserve Whole Object: pass the object itself instead of several values read from it.
- Replace Parameter with Method Call: let the callee compute a value it can obtain itself.

A parameter-object sketch for _do_query() follows.
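
_do_query() takes nine parameters (nop 9 above). A minimal Introduce Parameter Object sketch, assuming a hypothetical QueryRequest dataclass (everything except the parameter names is invented):

from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class QueryRequest:
    # One object carries what _do_query() currently takes as separate
    # arguments, so routing, search, and merge all see the same bundle.
    collection_id: str
    collection_meta: Any
    vectors: list
    topk: int
    search_params: dict
    partition_tags: Optional[list] = None
    metadata: Optional[dict] = None

# The signature then shrinks to:
# def _do_query(self, context, request: QueryRequest): ...
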

import logging
import time
import json
import ujson

import multiprocessing
from milvus.grpc_gen import milvus_pb2, milvus_pb2_grpc, status_pb2
from milvus.client import types as Types
from milvus import MetricType

from mishards import (db, exceptions)
from mishards.grpc_utils import mark_grpc_method
from mishards.grpc_utils.grpc_args_parser import GrpcArgsParser as Parser

logger = logging.getLogger(__name__)


class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
    MAX_NPROBE = 2048
    MAX_TOPK = 2048

    def __init__(self, tracer, router, max_workers=multiprocessing.cpu_count(), **kwargs):
        self.collection_meta = {}
        self.error_handlers = {}
        self.tracer = tracer
        self.router = router
        self.max_workers = max_workers

    def _reduce(self, source_ids, ids, source_diss, diss, k, reverse):
        # Merge two sorted top-k result lists (ascending distances, or
        # descending when reverse=True) into a single top-k list.
        sort_f = (lambda x, y: x >= y) if reverse else (lambda x, y: x <= y)

        # Fast paths: one list already dominates the other.
        if sort_f(source_diss[k - 1], diss[0]):
            return source_ids, source_diss
        if sort_f(diss[k - 1], source_diss[0]):
            return ids, diss

        # Otherwise sort the concatenated distances and keep the best k,
        # remembering original positions so the ids can be re-aligned.
        source_diss.extend(diss)
        diss_t = enumerate(source_diss)
        diss_m_rst = sorted(diss_t, key=lambda x: x[1], reverse=reverse)[:k]
        diss_m_out = [dis for _, dis in diss_m_rst]

        source_ids.extend(ids)
        id_m_out = [source_ids[i] for i, _ in diss_m_rst]

        return id_m_out, diss_m_out

    def _do_merge(self, files_n_topk_results, topk, reverse=False, **kwargs):
        status = status_pb2.Status(error_code=status_pb2.SUCCESS,
                                   reason="Success")
        if not files_n_topk_results:
            return status, [], []

        merge_id_results = []
        merge_dis_results = []

        calc_time = time.time()
        for files_collection in files_n_topk_results:
            if isinstance(files_collection, tuple):
                status, _ = files_collection
                return status, [], []

            if files_collection.status.error_code != 0:
                return files_collection.status, [], []

            row_num = files_collection.row_num
            # row_num == 0 means this result set is empty
            if not row_num:
                continue

            ids = files_collection.ids
            diss = files_collection.distances  # distance collections
            # TODO: batch_len should equal topk; may need to compare with topk
            batch_len = len(ids) // row_num

            for row_index in range(row_num):
                id_batch = ids[row_index * batch_len: (row_index + 1) * batch_len]
                dis_batch = diss[row_index * batch_len: (row_index + 1) * batch_len]

                if len(merge_id_results) < row_index:
                    raise ValueError("merge error")
                elif len(merge_id_results) == row_index:
                    # TODO: may be a bug here
                    merge_id_results.append(id_batch)
                    merge_dis_results.append(dis_batch)
                else:
                    merge_id_results[row_index], merge_dis_results[row_index] = \
                        self._reduce(merge_id_results[row_index], id_batch,
                                     merge_dis_results[row_index], dis_batch,
                                     batch_len,
                                     reverse)

        calc_time = time.time() - calc_time
        logger.info('Merge takes {}'.format(calc_time))

        id_merge_list = []
        dis_merge_list = []

        for id_results, dis_results in zip(merge_id_results, merge_dis_results):
            id_merge_list.extend(id_results)
            dis_merge_list.extend(dis_results)

        return status, id_merge_list, dis_merge_list

    def _do_query(self,
                  context,
                  collection_id,
                  collection_meta,
                  vectors,
                  topk,
                  search_params,
                  partition_tags=None,
                  **kwargs):
        metadata = kwargs.get('metadata', None)

        routing = {}
        p_span = None if self.tracer.empty else context.get_active_span().context
        with self.tracer.start_span('get_routing', child_of=p_span):
            routing = self.router.routing(collection_id,
                                          partition_tags=partition_tags,
                                          metadata=metadata)
        logger.info('Routing: {}'.format(routing))

        all_topk_results = []

        with self.tracer.start_span('do_search', child_of=p_span) as span:
            if len(routing) == 0:
                # No routing information: fall back to a single direct search.
                ft = self.router.connection().search(collection_id, topk, vectors,
                                                     list(partition_tags),
                                                     search_params, _async=True)
                ret = ft.result(raw=True)
                all_topk_results.append(ret)
            else:
                futures = []
                start = time.time()
                for addr, files_tuple in routing.items():
                    search_file_ids, ud_file_ids = files_tuple
                    if ud_file_ids:
                        logger.debug(f"<{addr}> needed update segment ids {ud_file_ids}")
                    conn = self.router.query_conn(addr, metadata=metadata)
                    if ud_file_ids:
                        conn.reload_segments(collection_id, ud_file_ids)
                    span = kwargs.get('span', None)
                    span = span if span else (None if self.tracer.empty else
                                              context.get_active_span().context)

                    with self.tracer.start_span('search_{}'.format(addr),
                                                child_of=span):
                        future = conn.search_in_segment(collection_name=collection_id,
                                                        file_ids=search_file_ids,
                                                        query_records=vectors,
                                                        top_k=topk,
                                                        params=search_params,
                                                        _async=True)
                        futures.append(future)

                for f in futures:
                    ret = f.result(raw=True)
                    all_topk_results.append(ret)
                logger.debug("Search in routing {} cost {} s".format(routing, time.time() - start))

        reverse = collection_meta.metric_type == Types.MetricType.IP
        with self.tracer.start_span('do_merge', child_of=p_span):
            return self._do_merge(all_topk_results,
                                  topk,
                                  reverse=reverse,
                                  metadata=metadata)

    def _create_collection(self, collection_schema):
        return self.router.connection().create_collection(collection_schema)

    @mark_grpc_method
    def CreateCollection(self, request, context):
        _status, unpacks = Parser.parse_proto_CollectionSchema(request)

        if not _status.OK():
            return status_pb2.Status(error_code=_status.code,
                                     reason=_status.message)

        _status, _collection_schema = unpacks
        # if _status.error_code != 0:
        #     logging.warning('[CreateCollection] collection schema error occurred: {}'.format(_status))
        #     return _status

        logger.info('CreateCollection {}'.format(_collection_schema['collection_name']))

        _status = self._create_collection(_collection_schema)

        return status_pb2.Status(error_code=_status.code,
                                 reason=_status.message)

    def _has_collection(self, collection_name, metadata=None):
        return self.router.connection(metadata=metadata).has_collection(collection_name)

    @mark_grpc_method
    def HasCollection(self, request, context):
        _status, _collection_name = Parser.parse_proto_CollectionName(request)

        if not _status.OK():
            return milvus_pb2.BoolReply(status=status_pb2.Status(
                error_code=_status.code, reason=_status.message),
                bool_reply=False)

        logger.info('HasCollection {}'.format(_collection_name))

        _status, _bool = self._has_collection(_collection_name,
                                              metadata={'resp_class': milvus_pb2.BoolReply})

        return milvus_pb2.BoolReply(status=status_pb2.Status(
            error_code=_status.code, reason=_status.message),
            bool_reply=_bool)

    @mark_grpc_method
    def CreatePartition(self, request, context):
        _collection_name, _tag = Parser.parse_proto_PartitionParam(request)
        _status = self.router.connection().create_partition(_collection_name, _tag)
        return status_pb2.Status(error_code=_status.code,
                                 reason=_status.message)

    @mark_grpc_method
    def DropPartition(self, request, context):
        _collection_name, _tag = Parser.parse_proto_PartitionParam(request)

        _status = self.router.connection().drop_partition(_collection_name, _tag)
        return status_pb2.Status(error_code=_status.code,
                                 reason=_status.message)

    @mark_grpc_method
    def HasPartition(self, request, context):
        _collection_name, _tag = Parser.parse_proto_PartitionParam(request)
        _status, _ok = self.router.connection().has_partition(_collection_name, _tag)
        return milvus_pb2.BoolReply(status=status_pb2.Status(
            error_code=_status.code, reason=_status.message),
            bool_reply=_ok)

    @mark_grpc_method
    def ShowPartitions(self, request, context):
        _status, _collection_name = Parser.parse_proto_CollectionName(request)
        if not _status.OK():
            return milvus_pb2.PartitionList(status=status_pb2.Status(
                error_code=_status.code, reason=_status.message),
                partition_array=[])

        logger.info('ShowPartitions {}'.format(_collection_name))

        _status, partition_array = self.router.connection().list_partitions(_collection_name)

        return milvus_pb2.PartitionList(status=status_pb2.Status(
            error_code=_status.code, reason=_status.message),
            partition_tag_array=[param.tag for param in partition_array])

    def _drop_collection(self, collection_name):
        return self.router.connection().drop_collection(collection_name)

    @mark_grpc_method
    def DropCollection(self, request, context):
        _status, _collection_name = Parser.parse_proto_CollectionName(request)

        if not _status.OK():
            return status_pb2.Status(error_code=_status.code,
                                     reason=_status.message)

        logger.info('DropCollection {}'.format(_collection_name))

        _status = self._drop_collection(_collection_name)

        return status_pb2.Status(error_code=_status.code,
                                 reason=_status.message)

    def _create_index(self, collection_name, index_type, param):
        return self.router.connection().create_index(collection_name, index_type, param)

    @mark_grpc_method
    def CreateIndex(self, request, context):
        _status, unpacks = Parser.parse_proto_IndexParam(request)

        if not _status.OK():
            return status_pb2.Status(error_code=_status.code,
                                     reason=_status.message)

        _collection_name, _index_type, _index_param = unpacks

        logger.info('CreateIndex {}'.format(_collection_name))

        # TODO: interface create_collection incomplete
        _status = self._create_index(_collection_name, _index_type, _index_param)

        return status_pb2.Status(error_code=_status.code,
                                 reason=_status.message)

    def _add_vectors(self, param, metadata=None):
        return self.router.connection(metadata=metadata).insert(
            None, None, insert_param=param)

    @mark_grpc_method
    def Insert(self, request, context):
        logger.info('Insert')
        # TODO: The SDK interface add_vectors() could update, add a key 'row_id_array'
        _status, _ids = self._add_vectors(
            metadata={'resp_class': milvus_pb2.VectorIds}, param=request)
        return milvus_pb2.VectorIds(status=status_pb2.Status(
            error_code=_status.code, reason=_status.message),
            vector_id_array=_ids)

    @mark_grpc_method
    def Search(self, request, context):

        metadata = {'resp_class': milvus_pb2.TopKQueryResult}

        collection_name = request.collection_name

        topk = request.topk

        if len(request.extra_params) == 0:
            raise exceptions.SearchParamError(message="Search params missing", metadata=metadata)
        params = ujson.loads(str(request.extra_params[0].value))

        logger.info('Search {}: topk={} params={}'.format(
            collection_name, topk, params))

        # if nprobe > self.MAX_NPROBE or nprobe <= 0:
        #     raise exceptions.InvalidArgumentError(
        #         message='Invalid nprobe: {}'.format(nprobe), metadata=metadata)

        # if topk > self.MAX_TOPK or topk <= 0:
        #     raise exceptions.InvalidTopKError(
        #         message='Invalid topk: {}'.format(topk), metadata=metadata)

        collection_meta = self.collection_meta.get(collection_name, None)

        if not collection_meta:
            status, info = self.router.connection(
                metadata=metadata).get_collection_info(collection_name)
            if not status.OK():
                raise exceptions.CollectionNotFoundError(collection_name,
                                                         metadata=metadata)

            self.collection_meta[collection_name] = info
            collection_meta = info

        start = time.time()

        query_record_array = []
        if int(collection_meta.metric_type) >= MetricType.HAMMING.value:
            for query_record in request.query_record_array:
                query_record_array.append(bytes(query_record.binary_data))
        else:
            for query_record in request.query_record_array:
                query_record_array.append(list(query_record.float_data))

        status, id_results, dis_results = self._do_query(context,
                                                         collection_name,
                                                         collection_meta,
                                                         query_record_array,
                                                         topk,
                                                         params,
                                                         partition_tags=getattr(request, "partition_tag_array", []),
                                                         metadata=metadata)

        now = time.time()
        logger.info('SearchVector takes: {}'.format(now - start))

        topk_result_list = milvus_pb2.TopKQueryResult(
            status=status_pb2.Status(error_code=status.error_code,
                                     reason=status.reason),
            row_num=len(request.query_record_array) if len(id_results) else 0,
            ids=id_results,
            distances=dis_results)
        return topk_result_list

    @mark_grpc_method
    def SearchInFiles(self, request, context):
        raise NotImplementedError()

    # @mark_grpc_method
    # def SearchByID(self, request, context):
    #     metadata = {'resp_class': milvus_pb2.TopKQueryResult}
    #
    #     collection_name = request.collection_name
    #
    #     topk = request.topk
    #
    #     if len(request.extra_params) == 0:
    #         raise exceptions.SearchParamError(message="Search param loss", metadata=metadata)
    #     params = ujson.loads(str(request.extra_params[0].value))
    #
    #     logger.info('Search {}: topk={} params={}'.format(
    #         collection_name, topk, params))
    #
    #     if topk > self.MAX_TOPK or topk <= 0:
    #         raise exceptions.InvalidTopKError(
    #             message='Invalid topk: {}'.format(topk), metadata=metadata)
    #
    #     collection_meta = self.collection_meta.get(collection_name, None)
    #
    #     if not collection_meta:
    #         status, info = self.router.connection(
    #             metadata=metadata).describe_collection(collection_name)
    #         if not status.OK():
    #             raise exceptions.CollectionNotFoundError(collection_name,
    #                                                      metadata=metadata)
    #
    #         self.collection_meta[collection_name] = info
    #         collection_meta = info
    #
    #     start = time.time()
    #
    #     query_record_array = []
    #     if int(collection_meta.metric_type) >= MetricType.HAMMING.value:
    #         for query_record in request.query_record_array:
    #             query_record_array.append(bytes(query_record.binary_data))
    #     else:
    #         for query_record in request.query_record_array:
    #             query_record_array.append(list(query_record.float_data))
    #
    #     partition_tags = getattr(request, "partition_tag_array", [])
    #     ids = getattr(request, "id_array", [])
    #     search_result = self.router.connection(metadata=metadata).search_by_ids(collection_name, ids, topk, partition_tags, params)
    #     # status, id_results, dis_results = self._do_query(context,
    #     #                                                  collection_name,
    #     #                                                  collection_meta,
    #     #                                                  query_record_array,
    #     #                                                  topk,
    #     #                                                  params,
    #     #                                                  partition_tags=getattr(request, "partition_tag_array", []),
    #     #                                                  metadata=metadata)
    #
    #     now = time.time()
    #     logger.info('SearchVector takes: {}'.format(now - start))
    #     return search_result
    #     #
    #     # topk_result_list = milvus_pb2.TopKQueryResult(
    #     #     status=status_pb2.Status(error_code=status.error_code,
    #     #                              reason=status.reason),
    #     #     row_num=len(request.query_record_array) if len(id_results) else 0,
    #     #     ids=id_results,
    #     #     distances=dis_results)
    #     # return topk_result_list
    #     # raise NotImplemented()

    def _describe_collection(self, collection_name, metadata=None):
        return self.router.connection(metadata=metadata).get_collection_info(collection_name)

    @mark_grpc_method
    def DescribeCollection(self, request, context):
        _status, _collection_name = Parser.parse_proto_CollectionName(request)

        if not _status.OK():
            return milvus_pb2.CollectionSchema(status=status_pb2.Status(
                error_code=_status.code, reason=_status.message), )

        metadata = {'resp_class': milvus_pb2.CollectionSchema}

        logger.info('DescribeCollection {}'.format(_collection_name))
        _status, _collection = self._describe_collection(metadata=metadata,
                                                         collection_name=_collection_name)

        if _status.OK():
            return milvus_pb2.CollectionSchema(
                collection_name=_collection_name,
                index_file_size=_collection.index_file_size,
                dimension=_collection.dimension,
                metric_type=_collection.metric_type,
                status=status_pb2.Status(error_code=_status.code,
                                         reason=_status.message),
            )

        return milvus_pb2.CollectionSchema(
            collection_name=_collection_name,
            status=status_pb2.Status(error_code=_status.code,
                                     reason=_status.message),
        )

    def _collection_info(self, collection_name, metadata=None):
        return self.router.connection(metadata=metadata).get_collection_stats(collection_name)

    @mark_grpc_method
    def ShowCollectionInfo(self, request, context):
        _status, _collection_name = Parser.parse_proto_CollectionName(request)

        if not _status.OK():
            return milvus_pb2.CollectionInfo(status=status_pb2.Status(
                error_code=_status.code, reason=_status.message), )

        metadata = {'resp_class': milvus_pb2.CollectionInfo}

        _status, _info = self._collection_info(metadata=metadata, collection_name=_collection_name)
        _info_str = ujson.dumps(_info)

        if _status.OK():
            return milvus_pb2.CollectionInfo(
                status=status_pb2.Status(error_code=_status.code,
                                         reason=_status.message),
                json_info=_info_str
            )

        return milvus_pb2.CollectionInfo(
            status=status_pb2.Status(error_code=_status.code,
                                     reason=_status.message),
        )

    def _count_collection(self, collection_name, metadata=None):
        return self.router.connection(
            metadata=metadata).count_entities(collection_name)

    @mark_grpc_method
    def CountCollection(self, request, context):
        _status, _collection_name = Parser.parse_proto_CollectionName(request)

        if not _status.OK():
            status = status_pb2.Status(error_code=_status.code,
                                       reason=_status.message)

            return milvus_pb2.CollectionRowCount(status=status)

        logger.info('CountCollection {}'.format(_collection_name))

        metadata = {'resp_class': milvus_pb2.CollectionRowCount}
        _status, _count = self._count_collection(_collection_name, metadata=metadata)

        return milvus_pb2.CollectionRowCount(
            status=status_pb2.Status(error_code=_status.code,
                                     reason=_status.message),
            collection_row_count=_count if isinstance(_count, int) else -1)

    def _get_server_version(self, metadata=None):
        return self.router.connection(metadata=metadata).server_version()

    def _cmd(self, cmd, metadata=None):
        return self.router.connection(metadata=metadata)._cmd(cmd)

    @mark_grpc_method
    def Cmd(self, request, context):
        _status, _cmd = Parser.parse_proto_Command(request)
        logger.info('Cmd: {}'.format(_cmd))

        if not _status.OK():
            return milvus_pb2.StringReply(status=status_pb2.Status(
                error_code=_status.code, reason=_status.message))

        metadata = {'resp_class': milvus_pb2.StringReply}

        if _cmd == 'conn_stats':
            stats = self.router.readonly_topo.stats()
            return milvus_pb2.StringReply(status=status_pb2.Status(
                error_code=status_pb2.SUCCESS),
                string_reply=json.dumps(stats, indent=2))

        # if _cmd == 'version':
        #     _status, _reply = self._get_server_version(metadata=metadata)
        # else:
        #     _status, _reply = self.router.connection(
        #         metadata=metadata).server_status()
        _status, _reply = self._cmd(_cmd, metadata=metadata)

        return milvus_pb2.StringReply(status=status_pb2.Status(
            error_code=_status.code, reason=_status.message),
            string_reply=_reply)

    def _show_collections(self, metadata=None):
        return self.router.connection(metadata=metadata).list_collections()

    @mark_grpc_method
    def ShowCollections(self, request, context):
        logger.info('ShowCollections')
        metadata = {'resp_class': milvus_pb2.CollectionName}
        _status, _results = self._show_collections(metadata=metadata)

        return milvus_pb2.CollectionNameList(status=status_pb2.Status(
            error_code=_status.code, reason=_status.message),
            collection_names=_results)

    def _preload_collection(self, collection_name, partition_tags):
        return self.router.connection().load_collection(collection_name, partition_tags)

    @mark_grpc_method
    def PreloadCollection(self, request, context):
        _status, _pack = Parser.parse_proto_PreloadCollectionParam(request)

        if not _status.OK():
            return status_pb2.Status(error_code=_status.code,
                                     reason=_status.message)

        _collection_name, _partition_tags = _pack

        logger.info('PreloadCollection {} | {}'.format(_collection_name, _partition_tags))
        _status = self._preload_collection(_collection_name, _partition_tags)
        return status_pb2.Status(error_code=_status.code,
                                 reason=_status.message)

    def ReloadSegments(self, request, context):
        raise NotImplementedError("Not implemented in mishards")

    def _describe_index(self, collection_name, metadata=None):
        return self.router.connection(metadata=metadata).get_index_info(collection_name)

    @mark_grpc_method
    def DescribeIndex(self, request, context):
        _status, _collection_name = Parser.parse_proto_CollectionName(request)

        if not _status.OK():
            return milvus_pb2.IndexParam(status=status_pb2.Status(
                error_code=_status.code, reason=_status.message))

        metadata = {'resp_class': milvus_pb2.IndexParam}

        logger.info('DescribeIndex {}'.format(_collection_name))
        _status, _index_param = self._describe_index(collection_name=_collection_name,
                                                     metadata=metadata)

        if not _index_param:
            return milvus_pb2.IndexParam(status=status_pb2.Status(
                error_code=_status.code, reason=_status.message))

        _index_type = _index_param._index_type

        grpc_index = milvus_pb2.IndexParam(status=status_pb2.Status(
            error_code=_status.code, reason=_status.message),
            collection_name=_collection_name, index_type=_index_type)

        grpc_index.extra_params.add(key='params', value=ujson.dumps(_index_param._params))
        return grpc_index

    def _get_vectors_by_id(self, collection_name, ids, metadata):
        return self.router.connection(metadata=metadata).get_entity_by_id(collection_name, ids)

    @mark_grpc_method
    def GetVectorsByID(self, request, context):
        _status, unpacks = Parser.parse_proto_VectorIdentity(request)
        if not _status.OK():
            return status_pb2.Status(error_code=_status.code,
                                     reason=_status.message)

        metadata = {'resp_class': milvus_pb2.VectorsData}

        _collection_name, _ids = unpacks
        logger.info('GetVectorByID {}'.format(_collection_name))
        _status, vectors = self._get_vectors_by_id(_collection_name, _ids, metadata)
        _rpc_status = status_pb2.Status(error_code=_status.code, reason=_status.message)
        if not vectors:
            return milvus_pb2.VectorsData(status=_rpc_status, )

        if len(vectors) == 0:
            return milvus_pb2.VectorsData(status=_rpc_status, vectors_data=[])
        if isinstance(vectors[0], bytes):
            records = [milvus_pb2.RowRecord(binary_data=v) for v in vectors]
        else:
            records = [milvus_pb2.RowRecord(float_data=v) for v in vectors]

        response = milvus_pb2.VectorsData(status=_rpc_status)
        response.vectors_data.extend(records)
        return response

    def _get_vector_ids(self, collection_name, segment_name, metadata):
        return self.router.connection(metadata=metadata).list_id_in_segment(collection_name, segment_name)

    @mark_grpc_method
    def GetVectorIDs(self, request, context):
        _status, unpacks = Parser.parse_proto_GetVectorIDsParam(request)

        if not _status.OK():
            return status_pb2.Status(error_code=_status.code,
                                     reason=_status.message)

        metadata = {'resp_class': milvus_pb2.VectorIds}

        _collection_name, _segment_name = unpacks
        logger.info('GetVectorIDs {}'.format(_collection_name))
        _status, ids = self._get_vector_ids(_collection_name, _segment_name, metadata)

        if not ids:
            return milvus_pb2.VectorIds(status=status_pb2.Status(
                error_code=_status.code, reason=_status.message), )

        return milvus_pb2.VectorIds(status=status_pb2.Status(
            error_code=_status.code, reason=_status.message),
            vector_id_array=ids
        )

    def _delete_by_id(self, collection_name, id_array):
        return self.router.connection().delete_entity_by_id(collection_name, id_array)

    @mark_grpc_method
    def DeleteByID(self, request, context):
        _status, unpacks = Parser.parse_proto_DeleteByIDParam(request)

        if not _status.OK():
            logger.error('DeleteByID {}'.format(_status.message))
            return status_pb2.Status(error_code=_status.code,
                                     reason=_status.message)

        _collection_name, _ids = unpacks
        logger.info('DeleteByID {}'.format(_collection_name))
        _status = self._delete_by_id(_collection_name, _ids)

        return status_pb2.Status(error_code=_status.code,
                                 reason=_status.message)

    def _drop_index(self, collection_name):
        return self.router.connection().drop_index(collection_name)

    @mark_grpc_method
    def DropIndex(self, request, context):
        _status, _collection_name = Parser.parse_proto_CollectionName(request)

        if not _status.OK():
            return status_pb2.Status(error_code=_status.code,
                                     reason=_status.message)

        logger.info('DropIndex {}'.format(_collection_name))
        _status = self._drop_index(_collection_name)
        return status_pb2.Status(error_code=_status.code,
                                 reason=_status.message)

    def _flush(self, collection_names):
        return self.router.connection().flush(collection_names)

    @mark_grpc_method
    def Flush(self, request, context):
        _status, _collection_names = Parser.parse_proto_FlushParam(request)

        if not _status.OK():
            return status_pb2.Status(error_code=_status.code,
                                     reason=_status.message)

        logger.info('Flush {}'.format(_collection_names))
        _status = self._flush(_collection_names)
        return status_pb2.Status(error_code=_status.code,
                                 reason=_status.message)

    def _compact(self, collection_name):
        return self.router.connection().compact(collection_name)

    @mark_grpc_method
    def Compact(self, request, context):
        _status, _collection_name = Parser.parse_proto_CollectionName(request)

        if not _status.OK():
            return status_pb2.Status(error_code=_status.code,
                                     reason=_status.message)

        logger.info('Compact {}'.format(_collection_name))
        _status = self._compact(_collection_name)
        return status_pb2.Status(error_code=_status.code,
                                 reason=_status.message)
743