mishards.service_handler.ServiceHandler.Compact()   A
Last analyzed

Complexity
    Conditions: 2

Size
    Total Lines: 12
    Code Lines: 10

Duplication
    Lines: 0
    Ratio: 0 %

Importance
    Changes: 0

Metric   Value
cc       2
eloc     10
nop      3
dl       0
loc      12
rs       9.9
c        0
b        0
f        0
import logging
import time
import json
import ujson

import multiprocessing
from milvus.grpc_gen import milvus_pb2, milvus_pb2_grpc, status_pb2
from milvus.client import types as Types
from milvus import MetricType

from mishards import (db, exceptions)
from mishards.grpc_utils import mark_grpc_method
from mishards.grpc_utils.grpc_args_parser import GrpcArgsParser as Parser

logger = logging.getLogger(__name__)


class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
    MAX_NPROBE = 2048
    MAX_TOPK = 2048

    def __init__(self, tracer, router, max_workers=multiprocessing.cpu_count(), **kwargs):
        self.collection_meta = {}
        self.error_handlers = {}
        self.tracer = tracer
        self.router = router
        self.max_workers = max_workers

    def _reduce(self, source_ids, ids, source_diss, diss, k, reverse):
        sort_f = (lambda x, y: x >= y) if reverse else (lambda x, y: x <= y)

        if sort_f(source_diss[k - 1], diss[0]):
            return source_ids, source_diss
        if sort_f(diss[k - 1], source_diss[0]):
            return ids, diss

        source_diss.extend(diss)
        diss_t = enumerate(source_diss)
        diss_m_rst = sorted(diss_t, key=lambda x: x[1], reverse=reverse)[:k]
        diss_m_out = [id_ for _, id_ in diss_m_rst]

        source_ids.extend(ids)
        id_m_out = [source_ids[i] for i, _ in diss_m_rst]

        return id_m_out, diss_m_out
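
    # Illustrative example (added comment, not executed): with reverse=False
    # (smaller distance is better) and k=3, merging
    #   source_ids = [1, 2, 3], source_diss = [0.1, 0.4, 0.9]
    #   ids        = [7, 8, 9], diss        = [0.2, 0.3, 1.5]
    # yields id_m_out = [1, 7, 8] and diss_m_out = [0.1, 0.2, 0.3].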

    def _do_merge(self, files_n_topk_results, topk, reverse=False, **kwargs):
        status = status_pb2.Status(error_code=status_pb2.SUCCESS,
                                   reason="Success")
        if not files_n_topk_results:
            return status, [], []

        merge_id_results = []
        merge_dis_results = []

        calc_time = time.time()
        for files_collection in files_n_topk_results:
            if isinstance(files_collection, tuple):
                status, _ = files_collection
                return status, [], []

            if files_collection.status.error_code != 0:
                return files_collection.status, [], []

            row_num = files_collection.row_num
            # row_num == 0 means this shard returned an empty result
            if not row_num:
                continue

            ids = files_collection.ids
            diss = files_collection.distances  # distance collections
            # TODO: batch_len should equal topk; may need to compare it with topk
            batch_len = len(ids) // row_num

            for row_index in range(row_num):
                id_batch = ids[row_index * batch_len: (row_index + 1) * batch_len]
                dis_batch = diss[row_index * batch_len: (row_index + 1) * batch_len]

                if len(merge_id_results) < row_index:
                    raise ValueError("merge error")
                elif len(merge_id_results) == row_index:
                    # TODO: possible bug here
                    merge_id_results.append(id_batch)
                    merge_dis_results.append(dis_batch)
                else:
                    merge_id_results[row_index], merge_dis_results[row_index] = \
                        self._reduce(merge_id_results[row_index], id_batch,
                                     merge_dis_results[row_index], dis_batch,
                                     batch_len,
                                     reverse)

        calc_time = time.time() - calc_time
        logger.info('Merge takes {}'.format(calc_time))

        id_merge_list = []
        dis_merge_list = []

        for id_results, dis_results in zip(merge_id_results, merge_dis_results):
            id_merge_list.extend(id_results)
            dis_merge_list.extend(dis_results)

        return status, id_merge_list, dis_merge_list

    def _do_query(self,
                  context,
                  collection_id,
                  collection_meta,
                  vectors,
                  topk,
                  search_params,
                  partition_tags=None,
                  **kwargs):
        metadata = kwargs.get('metadata', None)

        routing = {}
        p_span = None if self.tracer.empty else context.get_active_span(
        ).context
        with self.tracer.start_span('get_routing', child_of=p_span):
            # routing maps a server address to a tuple of
            # (segment ids to search, segment ids that need to be reloaded)
            routing = self.router.routing(collection_id,
                                          partition_tags=partition_tags,
                                          metadata=metadata)
        logger.info('Routing: {}'.format(routing))

        all_topk_results = []

        with self.tracer.start_span('do_search', child_of=p_span) as span:
            if len(routing) == 0:
                ft = self.router.connection().search(collection_id, topk, vectors, list(partition_tags), search_params, _async=True)
                ret = ft.result(raw=True)
                all_topk_results.append(ret)
            else:
                futures = []
                start = time.time()
                for addr, files_tuple in routing.items():
                    search_file_ids, ud_file_ids = files_tuple
                    if ud_file_ids:
                        logger.debug(f"<{addr}> needed update segment ids {ud_file_ids}")
                    conn = self.router.query_conn(addr, metadata=metadata)
                    ud_file_ids and conn.reload_segments(collection_id, ud_file_ids)
                    span = kwargs.get('span', None)
                    span = span if span else (None if self.tracer.empty else
                                              context.get_active_span().context)

                    with self.tracer.start_span('search_{}'.format(addr),
                                                child_of=span):
                        future = conn.search_in_segment(collection_name=collection_id,
                                                        file_ids=search_file_ids,
                                                        query_records=vectors,
                                                        top_k=topk,
                                                        params=search_params, _async=True)
                        futures.append(future)

                for f in futures:
                    ret = f.result(raw=True)
                    all_topk_results.append(ret)
                logger.debug("Search in routing {} cost {} s".format(routing, time.time() - start))

        reverse = collection_meta.metric_type == Types.MetricType.IP
        with self.tracer.start_span('do_merge', child_of=p_span):
            return self._do_merge(all_topk_results,
                                  topk,
                                  reverse=reverse,
                                  metadata=metadata)

    def _create_collection(self, collection_schema):
        return self.router.connection().create_collection(collection_schema)

    @mark_grpc_method
    def CreateCollection(self, request, context):
        _status, unpacks = Parser.parse_proto_CollectionSchema(request)

        if not _status.OK():
            return status_pb2.Status(error_code=_status.code,
                                     reason=_status.message)

        _status, _collection_schema = unpacks
        # if _status.error_code != 0:
        #     logging.warning('[CreateCollection] collection schema error occurred: {}'.format(_status))
        #     return _status

        logger.info('CreateCollection {}'.format(_collection_schema['collection_name']))

        _status = self._create_collection(_collection_schema)

        return status_pb2.Status(error_code=_status.code,
                                 reason=_status.message)

    def _has_collection(self, collection_name, metadata=None):
        return self.router.connection(metadata=metadata).has_collection(collection_name)

    @mark_grpc_method
    def HasCollection(self, request, context):
        _status, _collection_name = Parser.parse_proto_CollectionName(request)

        if not _status.OK():
            return milvus_pb2.BoolReply(status=status_pb2.Status(
                error_code=_status.code, reason=_status.message),
                bool_reply=False)

        logger.info('HasCollection {}'.format(_collection_name))

        _status, _bool = self._has_collection(_collection_name,
                                              metadata={'resp_class': milvus_pb2.BoolReply})

        return milvus_pb2.BoolReply(status=status_pb2.Status(
            error_code=_status.code, reason=_status.message),
            bool_reply=_bool)

    @mark_grpc_method
    def CreatePartition(self, request, context):
        _collection_name, _tag = Parser.parse_proto_PartitionParam(request)
        _status = self.router.connection().create_partition(_collection_name, _tag)
        return status_pb2.Status(error_code=_status.code,
                                 reason=_status.message)

    @mark_grpc_method
    def DropPartition(self, request, context):
        _collection_name, _tag = Parser.parse_proto_PartitionParam(request)

        _status = self.router.connection().drop_partition(_collection_name, _tag)
        return status_pb2.Status(error_code=_status.code,
                                 reason=_status.message)

    @mark_grpc_method
    def HasPartition(self, request, context):
        _collection_name, _tag = Parser.parse_proto_PartitionParam(request)
        _status, _ok = self.router.connection().has_partition(_collection_name, _tag)
        return milvus_pb2.BoolReply(status=status_pb2.Status(error_code=_status.code,
                                    reason=_status.message), bool_reply=_ok)

    @mark_grpc_method
    def ShowPartitions(self, request, context):
        _status, _collection_name = Parser.parse_proto_CollectionName(request)
        if not _status.OK():
            return milvus_pb2.PartitionList(status=status_pb2.Status(
                error_code=_status.code, reason=_status.message),
                partition_array=[])

        logger.info('ShowPartitions {}'.format(_collection_name))

        _status, partition_array = self.router.connection().list_partitions(_collection_name)

        return milvus_pb2.PartitionList(status=status_pb2.Status(
            error_code=_status.code, reason=_status.message),
            partition_tag_array=[param.tag for param in partition_array])

    def _drop_collection(self, collection_name):
        return self.router.connection().drop_collection(collection_name)

    @mark_grpc_method
    def DropCollection(self, request, context):
        _status, _collection_name = Parser.parse_proto_CollectionName(request)

        if not _status.OK():
            return status_pb2.Status(error_code=_status.code,
                                     reason=_status.message)

        logger.info('DropCollection {}'.format(_collection_name))

        _status = self._drop_collection(_collection_name)

        return status_pb2.Status(error_code=_status.code,
                                 reason=_status.message)

    def _create_index(self, collection_name, index_type, param):
        return self.router.connection().create_index(collection_name, index_type, param)

    @mark_grpc_method
    def CreateIndex(self, request, context):
        _status, unpacks = Parser.parse_proto_IndexParam(request)

        if not _status.OK():
            return status_pb2.Status(error_code=_status.code,
                                     reason=_status.message)

        _collection_name, _index_type, _index_param = unpacks

        logger.info('CreateIndex {}'.format(_collection_name))

        # TODO: interface create_collection is incomplete
        _status = self._create_index(_collection_name, _index_type, _index_param)

        return status_pb2.Status(error_code=_status.code,
                                 reason=_status.message)

    def _add_vectors(self, param, metadata=None):
        return self.router.connection(metadata=metadata).insert(
            None, None, insert_param=param)

    @mark_grpc_method
    def Insert(self, request, context):
        logger.info('Insert')
        # TODO: The SDK interface add_vectors() could be updated to accept a key 'row_id_array'
        _status, _ids = self._add_vectors(
            metadata={'resp_class': milvus_pb2.VectorIds}, param=request)
        return milvus_pb2.VectorIds(status=status_pb2.Status(
            error_code=_status.code, reason=_status.message),
            vector_id_array=_ids)

    @mark_grpc_method
    def Search(self, request, context):

        metadata = {'resp_class': milvus_pb2.TopKQueryResult}

        collection_name = request.collection_name

        topk = request.topk

        if len(request.extra_params) == 0:
            raise exceptions.SearchParamError(message="Search param loss", metadata=metadata)
        # extra_params[0].value is expected to be a JSON-encoded string of
        # index-specific search parameters, e.g. '{"nprobe": 16}'
        params = ujson.loads(str(request.extra_params[0].value))

        logger.info('Search {}: topk={} params={}'.format(
            collection_name, topk, params))

        # if nprobe > self.MAX_NPROBE or nprobe <= 0:
        #     raise exceptions.InvalidArgumentError(
        #         message='Invalid nprobe: {}'.format(nprobe), metadata=metadata)

        # if topk > self.MAX_TOPK or topk <= 0:
        #     raise exceptions.InvalidTopKError(
        #         message='Invalid topk: {}'.format(topk), metadata=metadata)

        collection_meta = self.collection_meta.get(collection_name, None)

        if not collection_meta:
            status, info = self.router.connection(
                metadata=metadata).get_collection_info(collection_name)
            if not status.OK():
                raise exceptions.CollectionNotFoundError(collection_name,
                                                         metadata=metadata)

            self.collection_meta[collection_name] = info
            collection_meta = info

        start = time.time()

        query_record_array = []
        if int(collection_meta.metric_type) >= MetricType.HAMMING.value:
            for query_record in request.query_record_array:
                query_record_array.append(bytes(query_record.binary_data))
        else:
            for query_record in request.query_record_array:
                query_record_array.append(list(query_record.float_data))

        status, id_results, dis_results = self._do_query(context,
                                                         collection_name,
                                                         collection_meta,
                                                         query_record_array,
                                                         topk,
                                                         params,
                                                         partition_tags=getattr(request, "partition_tag_array", []),
                                                         metadata=metadata)

        now = time.time()
        logger.info('SearchVector takes: {}'.format(now - start))

        topk_result_list = milvus_pb2.TopKQueryResult(
            status=status_pb2.Status(error_code=status.error_code,
                                     reason=status.reason),
            row_num=len(request.query_record_array) if len(id_results) else 0,
            ids=id_results,
            distances=dis_results)
        return topk_result_list
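
    # Note added for clarity: the TopKQueryResult above carries the merged
    # results in flattened form: `ids` and `distances` each contain row_num
    # consecutive blocks (one block per query vector, typically topk entries
    # each), in the same order as request.query_record_array.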

    @mark_grpc_method
    def SearchInFiles(self, request, context):
        raise NotImplementedError()

    # @mark_grpc_method
    # def SearchByID(self, request, context):
    #     metadata = {'resp_class': milvus_pb2.TopKQueryResult}
    #
    #     collection_name = request.collection_name
    #
    #     topk = request.topk
    #
    #     if len(request.extra_params) == 0:
    #         raise exceptions.SearchParamError(message="Search param loss", metadata=metadata)
    #     params = ujson.loads(str(request.extra_params[0].value))
    #
    #     logger.info('Search {}: topk={} params={}'.format(
    #         collection_name, topk, params))
    #
    #     if topk > self.MAX_TOPK or topk <= 0:
    #         raise exceptions.InvalidTopKError(
    #             message='Invalid topk: {}'.format(topk), metadata=metadata)
    #
    #     collection_meta = self.collection_meta.get(collection_name, None)
    #
    #     if not collection_meta:
    #         status, info = self.router.connection(
    #             metadata=metadata).describe_collection(collection_name)
    #         if not status.OK():
    #             raise exceptions.CollectionNotFoundError(collection_name,
    #                                                      metadata=metadata)
    #
    #         self.collection_meta[collection_name] = info
    #         collection_meta = info
    #
    #     start = time.time()
    #
    #     query_record_array = []
    #     if int(collection_meta.metric_type) >= MetricType.HAMMING.value:
    #         for query_record in request.query_record_array:
    #             query_record_array.append(bytes(query_record.binary_data))
    #     else:
    #         for query_record in request.query_record_array:
    #             query_record_array.append(list(query_record.float_data))
    #
    #     partition_tags = getattr(request, "partition_tag_array", [])
    #     ids = getattr(request, "id_array", [])
    #     search_result = self.router.connection(metadata=metadata).search_by_ids(collection_name, ids, topk, partition_tags, params)
    #     # status, id_results, dis_results = self._do_query(context,
    #     #                                                  collection_name,
    #     #                                                  collection_meta,
    #     #                                                  query_record_array,
    #     #                                                  topk,
    #     #                                                  params,
    #     #                                                  partition_tags=getattr(request, "partition_tag_array", []),
    #     #                                                  metadata=metadata)
    #
    #     now = time.time()
    #     logger.info('SearchVector takes: {}'.format(now - start))
    #     return search_result
    #     #
    #     # topk_result_list = milvus_pb2.TopKQueryResult(
    #     #     status=status_pb2.Status(error_code=status.error_code,
    #     #                              reason=status.reason),
    #     #     row_num=len(request.query_record_array) if len(id_results) else 0,
    #     #     ids=id_results,
    #     #     distances=dis_results)
    #     # return topk_result_list
    #     # raise NotImplemented()

    def _describe_collection(self, collection_name, metadata=None):
        return self.router.connection(metadata=metadata).get_collection_info(collection_name)

    @mark_grpc_method
    def DescribeCollection(self, request, context):
        _status, _collection_name = Parser.parse_proto_CollectionName(request)

        if not _status.OK():
            return milvus_pb2.CollectionSchema(status=status_pb2.Status(
                error_code=_status.code, reason=_status.message), )

        metadata = {'resp_class': milvus_pb2.CollectionSchema}

        logger.info('DescribeCollection {}'.format(_collection_name))
        _status, _collection = self._describe_collection(metadata=metadata,
                                                         collection_name=_collection_name)

        if _status.OK():
            return milvus_pb2.CollectionSchema(
                collection_name=_collection_name,
                index_file_size=_collection.index_file_size,
                dimension=_collection.dimension,
                metric_type=_collection.metric_type,
                status=status_pb2.Status(error_code=_status.code,
                                         reason=_status.message),
            )

        return milvus_pb2.CollectionSchema(
            collection_name=_collection_name,
            status=status_pb2.Status(error_code=_status.code,
                                     reason=_status.message),
        )

    def _collection_info(self, collection_name, metadata=None):
        return self.router.connection(metadata=metadata).get_collection_stats(collection_name)

    @mark_grpc_method
    def ShowCollectionInfo(self, request, context):
        _status, _collection_name = Parser.parse_proto_CollectionName(request)

        if not _status.OK():
            return milvus_pb2.CollectionInfo(status=status_pb2.Status(
                error_code=_status.code, reason=_status.message), )

        metadata = {'resp_class': milvus_pb2.CollectionInfo}

        _status, _info = self._collection_info(metadata=metadata, collection_name=_collection_name)
        _info_str = ujson.dumps(_info)

        if _status.OK():
            return milvus_pb2.CollectionInfo(
                status=status_pb2.Status(error_code=_status.code,
                                         reason=_status.message),
                json_info=_info_str
            )

        return milvus_pb2.CollectionInfo(
            status=status_pb2.Status(error_code=_status.code,
                                     reason=_status.message),
        )

    def _count_collection(self, collection_name, metadata=None):
        return self.router.connection(
            metadata=metadata).count_entities(collection_name)

    @mark_grpc_method
    def CountCollection(self, request, context):
        _status, _collection_name = Parser.parse_proto_CollectionName(request)

        if not _status.OK():
            status = status_pb2.Status(error_code=_status.code,
                                       reason=_status.message)

            return milvus_pb2.CollectionRowCount(status=status)

        logger.info('CountCollection {}'.format(_collection_name))

        metadata = {'resp_class': milvus_pb2.CollectionRowCount}
        _status, _count = self._count_collection(_collection_name, metadata=metadata)

        return milvus_pb2.CollectionRowCount(
            status=status_pb2.Status(error_code=_status.code,
                                     reason=_status.message),
            collection_row_count=_count if isinstance(_count, int) else -1)

    def _get_server_version(self, metadata=None):
        return self.router.connection(metadata=metadata).server_version()

    def _cmd(self, cmd, metadata=None):
        return self.router.connection(metadata=metadata)._cmd(cmd)

    @mark_grpc_method
    def Cmd(self, request, context):
        _status, _cmd = Parser.parse_proto_Command(request)
        logger.info('Cmd: {}'.format(_cmd))

        if not _status.OK():
            return milvus_pb2.StringReply(status=status_pb2.Status(
                error_code=_status.code, reason=_status.message))

        metadata = {'resp_class': milvus_pb2.StringReply}

        if _cmd == 'conn_stats':
            stats = self.router.readonly_topo.stats()
            return milvus_pb2.StringReply(status=status_pb2.Status(
                error_code=status_pb2.SUCCESS),
                string_reply=json.dumps(stats, indent=2))

        # if _cmd == 'version':
        #     _status, _reply = self._get_server_version(metadata=metadata)
        # else:
        #     _status, _reply = self.router.connection(
        #         metadata=metadata).server_status()
        _status, _reply = self._cmd(_cmd, metadata=metadata)

        return milvus_pb2.StringReply(status=status_pb2.Status(
            error_code=_status.code, reason=_status.message),
            string_reply=_reply)

    def _show_collections(self, metadata=None):
        return self.router.connection(metadata=metadata).list_collections()

    @mark_grpc_method
    def ShowCollections(self, request, context):
        logger.info('ShowCollections')
        metadata = {'resp_class': milvus_pb2.CollectionName}
        _status, _results = self._show_collections(metadata=metadata)

        return milvus_pb2.CollectionNameList(status=status_pb2.Status(
            error_code=_status.code, reason=_status.message),
            collection_names=_results)

    def _preload_collection(self, collection_name, partition_tags):
        return self.router.connection().load_collection(collection_name, partition_tags)

    @mark_grpc_method
    def PreloadCollection(self, request, context):
        _status, _pack = Parser.parse_proto_PreloadCollectionParam(request)

        if not _status.OK():
            return status_pb2.Status(error_code=_status.code,
                                     reason=_status.message)

        _collection_name, _partition_tags = _pack

        logger.info('PreloadCollection {} | {}'.format(_collection_name, _partition_tags))
        _status = self._preload_collection(_collection_name, _partition_tags)
        return status_pb2.Status(error_code=_status.code,
                                 reason=_status.message)

    def ReloadSegments(self, request, context):
        raise NotImplementedError("Not implemented in mishards")

    def _describe_index(self, collection_name, metadata=None):
        return self.router.connection(metadata=metadata).get_index_info(collection_name)

    @mark_grpc_method
    def DescribeIndex(self, request, context):
        _status, _collection_name = Parser.parse_proto_CollectionName(request)

        if not _status.OK():
            return milvus_pb2.IndexParam(status=status_pb2.Status(
                error_code=_status.code, reason=_status.message))

        metadata = {'resp_class': milvus_pb2.IndexParam}

        logger.info('DescribeIndex {}'.format(_collection_name))
        _status, _index_param = self._describe_index(collection_name=_collection_name,
                                                     metadata=metadata)

        if not _index_param:
            return milvus_pb2.IndexParam(status=status_pb2.Status(
                error_code=_status.code, reason=_status.message))

        _index_type = _index_param._index_type

        grpc_index = milvus_pb2.IndexParam(status=status_pb2.Status(
            error_code=_status.code, reason=_status.message),
            collection_name=_collection_name, index_type=_index_type)

        grpc_index.extra_params.add(key='params', value=ujson.dumps(_index_param._params))
        return grpc_index

    def _get_vectors_by_id(self, collection_name, ids, metadata):
        return self.router.connection(metadata=metadata).get_entity_by_id(collection_name, ids)

    @mark_grpc_method
    def GetVectorsByID(self, request, context):
        _status, unpacks = Parser.parse_proto_VectorIdentity(request)
        if not _status.OK():
            return status_pb2.Status(error_code=_status.code,
                                     reason=_status.message)

        metadata = {'resp_class': milvus_pb2.VectorsData}

        _collection_name, _ids = unpacks
        logger.info('GetVectorByID {}'.format(_collection_name))
        _status, vectors = self._get_vectors_by_id(_collection_name, _ids, metadata)
        _rpc_status = status_pb2.Status(error_code=_status.code, reason=_status.message)
        if not vectors:
            return milvus_pb2.VectorsData(status=_rpc_status, )

        if len(vectors) == 0:
            return milvus_pb2.VectorsData(status=_rpc_status, vectors_data=[])
        if isinstance(vectors[0], bytes):
            records = [milvus_pb2.RowRecord(binary_data=v) for v in vectors]
        else:
            records = [milvus_pb2.RowRecord(float_data=v) for v in vectors]

        response = milvus_pb2.VectorsData(status=_rpc_status)
        response.vectors_data.extend(records)
        return response

    def _get_vector_ids(self, collection_name, segment_name, metadata):
        return self.router.connection(metadata=metadata).list_id_in_segment(collection_name, segment_name)

    @mark_grpc_method
    def GetVectorIDs(self, request, context):
        _status, unpacks = Parser.parse_proto_GetVectorIDsParam(request)

        if not _status.OK():
            return status_pb2.Status(error_code=_status.code,
                                     reason=_status.message)

        metadata = {'resp_class': milvus_pb2.VectorIds}

        _collection_name, _segment_name = unpacks
        logger.info('GetVectorIDs {}'.format(_collection_name))
        _status, ids = self._get_vector_ids(_collection_name, _segment_name, metadata)

        if not ids:
            return milvus_pb2.VectorIds(status=status_pb2.Status(
                error_code=_status.code, reason=_status.message), )

        return milvus_pb2.VectorIds(status=status_pb2.Status(
            error_code=_status.code, reason=_status.message),
            vector_id_array=ids
        )

    def _delete_by_id(self, collection_name, id_array):
        return self.router.connection().delete_entity_by_id(collection_name, id_array)

    @mark_grpc_method
    def DeleteByID(self, request, context):
        _status, unpacks = Parser.parse_proto_DeleteByIDParam(request)

        if not _status.OK():
            logger.error('DeleteByID {}'.format(_status.message))
            return status_pb2.Status(error_code=_status.code,
                                     reason=_status.message)

        _collection_name, _ids = unpacks
        logger.info('DeleteByID {}'.format(_collection_name))
        _status = self._delete_by_id(_collection_name, _ids)

        return status_pb2.Status(error_code=_status.code,
                                 reason=_status.message)

    def _drop_index(self, collection_name):
        return self.router.connection().drop_index(collection_name)

    @mark_grpc_method
    def DropIndex(self, request, context):
        _status, _collection_name = Parser.parse_proto_CollectionName(request)

        if not _status.OK():
            return status_pb2.Status(error_code=_status.code,
                                     reason=_status.message)

        logger.info('DropIndex {}'.format(_collection_name))
        _status = self._drop_index(_collection_name)
        return status_pb2.Status(error_code=_status.code,
                                 reason=_status.message)

    def _flush(self, collection_names):
        return self.router.connection().flush(collection_names)

    @mark_grpc_method
    def Flush(self, request, context):
        _status, _collection_names = Parser.parse_proto_FlushParam(request)

        if not _status.OK():
            return status_pb2.Status(error_code=_status.code,
                                     reason=_status.message)

        logger.info('Flush {}'.format(_collection_names))
        _status = self._flush(_collection_names)
        return status_pb2.Status(error_code=_status.code,
                                 reason=_status.message)

    def _compact(self, collection_name):
        return self.router.connection().compact(collection_name)

    @mark_grpc_method
    def Compact(self, request, context):
        _status, _collection_name = Parser.parse_proto_CollectionName(request)

        if not _status.OK():
            return status_pb2.Status(error_code=_status.code,
                                     reason=_status.message)

        logger.info('Compact {}'.format(_collection_name))
        _status = self._compact(_collection_name)
        return status_pb2.Status(error_code=_status.code,
                                 reason=_status.message)
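
For reference, below is a minimal client-side sketch of calling the Compact RPC that this handler serves. It is an illustration, not part of the file: it assumes the same milvus_pb2 / milvus_pb2_grpc modules imported at the top of the listing, that the generated stub class is MilvusServiceStub, and that the CollectionName message exposes a collection_name field (mirroring Parser.parse_proto_CollectionName() above); the server address and collection name are placeholders.

import grpc

from milvus.grpc_gen import milvus_pb2, milvus_pb2_grpc


def compact_collection(address, collection_name):
    # Open a channel to the mishards proxy and invoke Compact; the RPC takes a
    # CollectionName message and returns a status_pb2.Status message.
    with grpc.insecure_channel(address) as channel:
        stub = milvus_pb2_grpc.MilvusServiceStub(channel)
        return stub.Compact(
            milvus_pb2.CollectionName(collection_name=collection_name))


# Example usage (placeholder address and collection name):
# status = compact_collection('127.0.0.1:19530', 'example_collection')
# print(status.error_code, status.reason)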