Passed
Pull Request — master (#4764)
by
unknown
02:04
created

RequestHeaderValidatorInterceptor.intercept_service()   A

Complexity

Conditions 2

Size

Total Lines 5
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 4
nop 3
dl 0
loc 5
rs 10
c 0
b 0
f 0
1
import os
2
import logging
3
import sys
4
import grpc
5
import time
6
import socket
7
import inspect
8
from urllib.parse import urlparse
9
from functools import wraps
10
from concurrent import futures
11
from grpc._cython import cygrpc
12
import milvus
13
from milvus.grpc_gen.milvus_pb2_grpc import add_MilvusServiceServicer_to_server
14
from mishards.grpc_utils import is_grpc_method
15
from mishards.service_handler import ServiceHandler
16
from mishards import settings
17
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
18
19
20
def _unary_unary_rpc_terminator(code, details):
    """Build a unary-unary RPC method handler that aborts every call.

    The handler ignores the request it receives and calls
    ``context.abort(code, details)`` on the call's context.

    Args:
        code: status code passed to ``context.abort``.
        details: detail message passed to ``context.abort``.

    Returns:
        The handler produced by ``grpc.unary_unary_rpc_method_handler``.
    """

    def _abort_call(_request, servicer_context):
        # The request payload is irrelevant: the call is always rejected.
        servicer_context.abort(code, details)

    rejecting_handler = grpc.unary_unary_rpc_method_handler(_abort_call)
    return rejecting_handler
26
27
28
class RequestHeaderValidatorInterceptor(grpc.ServerInterceptor):
    """Server interceptor that requires a specific metadata pair on each call.

    Calls whose invocation metadata contains ``(header, value)`` are passed
    on to the next handler; all other calls are answered by a terminator
    handler that aborts with the configured ``code`` and ``details``.
    """

    def __init__(self, header, value, code, details):
        """Store the expected metadata pair and pre-build the rejecting handler."""
        self._terminator = _unary_unary_rpc_terminator(code, details)
        self._header = header
        self._value = value

    def intercept_service(self, continuation, handler_call_details):
        """Return the next handler if the expected pair is present, else the terminator."""
        expected_pair = (self._header, self._value)
        if expected_pair not in handler_call_details.invocation_metadata:
            return self._terminator
        return continuation(handler_call_details)
39
40
41
class Server:
    """Wrapper around a gRPC server plus the objects it needs to serve requests.

    Typical flow as implemented here: ``init_app`` builds the gRPC server,
    ``run`` executes the registered pre-run handlers, starts serving and then
    blocks until ``stop`` is called or a ``KeyboardInterrupt`` arrives.
    """

    def __init__(self):
        self.pre_run_handlers = set()   # callables executed by on_pre_run()
        self.grpc_methods = set()
        self.error_handlers = {}        # exception class -> handler callable
        self.exit_flag = False          # set by stop() to end the run() loop

    def init_app(self,
                 writable_topo,
                 readonly_topo,
                 tracer,
                 router,
                 discover,
                 port=19530,
                 max_workers=10,
                 **kwargs):
        """Store collaborators and build the token-protected gRPC server.

        Args:
            writable_topo: topology object used by ``pre_run_handler``.
            readonly_topo: stored on the instance; not used in this class.
            tracer: must provide ``decorate(server)`` and ``close()``.
            router: passed to the service handler in ``start``.
            discover: must provide ``start()``; used by ``on_pre_run``.
            port: listening port, converted with ``int``.
            max_workers: size of the thread pool serving requests.
            **kwargs: accepted and ignored.
        """
        self.port = int(port)
        self.writable_topo = writable_topo
        self.readonly_topo = readonly_topo
        self.tracer = tracer
        self.router = router
        self.discover = discover

        token = os.getenv("MISHARDS_TOKEN")
        # Never write the secret itself to the log; only report its presence.
        logger.debug('Mishards token is %s',
                     'not set' if token is None else 'set')
        if token is None:
            # ('token', None) can never equal a metadata pair sent by a
            # client, so the interceptor below rejects every request.
            logger.warning('MISHARDS_TOKEN is not set; all requests will be '
                           'rejected by the token validator')
        logger.debug('Init grpc server with max_workers: {}'.format(max_workers))
        header_validator = RequestHeaderValidatorInterceptor(
            'token', token, grpc.StatusCode.UNAUTHENTICATED,
            'Access denied!')
        self.server_impl = grpc.server(
            thread_pool=futures.ThreadPoolExecutor(max_workers=max_workers),
            interceptors=(header_validator,),
            # -1 is gRPC's "unlimited" value for message-length limits.
            options=[(cygrpc.ChannelArgKey.max_send_message_length, -1),
                     (cygrpc.ChannelArgKey.max_receive_message_length, -1)])

        self.server_impl = self.tracer.decorate(self.server_impl)

        self.register_pre_run_handler(self.pre_run_handler)

    def pre_run_handler(self):
        """Register the host from ``settings.WOSERVER`` in the writable topology.

        The hostname is resolved to an IP address and added as ``'WOSERVER'``
        to the ``'default'`` group, keeping the URL scheme and using port 80
        when the URL carries none.
        """
        woserver = settings.WOSERVER
        url = urlparse(woserver)
        ip = socket.gethostbyname(url.hostname)
        # Result is discarded: the call only validates that ``ip`` is a
        # well-formed IPv4 address (it raises otherwise).
        socket.inet_pton(socket.AF_INET, ip)
        _, group = self.writable_topo.create('default')
        group.create(name='WOSERVER', uri='{}://{}:{}'.format(url.scheme, ip, url.port or 80))

    def register_pre_run_handler(self, func):
        """Add ``func`` to the pre-run handlers and return it (decorator-friendly)."""
        logger.info('Regiterring {} into server pre_run_handlers'.format(func))
        self.pre_run_handlers.add(func)
        return func

    def _find_error_handler(self, exc_type):
        """Return the handler registered for ``exc_type`` or its nearest base class.

        The MRO is walked from most to least specific, so a handler registered
        for the exact class always wins over one registered for a base class.
        Returns ``None`` when nothing matches.
        """
        for klass in exc_type.__mro__:
            if klass in self.error_handlers:
                return self.error_handlers[klass]
        return None

    def wrap_method_with_errorhandler(self, func):
        """Wrap ``func`` so raised exceptions are routed to registered handlers.

        If ``func`` raises an ``Exception`` for whose class (or a base class)
        a handler was registered via ``errorhandler``, the handler is called
        with the exception and its result is returned. Otherwise the
        exception propagates unchanged.
        """
        @wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception as e:
                handler = self._find_error_handler(type(e))
                if handler is None:
                    raise
                return handler(e)

        return wrapper

    def errorhandler(self, exception):
        """Return a decorator registering a handler for ``exception``.

        When ``exception`` is an ``Exception`` subclass, the returned
        decorator stores the decorated function in ``error_handlers`` and
        returns it unchanged. For any other argument nothing is registered
        and the argument itself is returned.
        """
        if inspect.isclass(exception) and issubclass(exception, Exception):

            def wrapper(func):
                self.error_handlers[exception] = func
                return func

            return wrapper
        return exception

    def on_pre_run(self):
        """Call every pre-run handler, then return ``self.discover.start()``."""
        for handler in self.pre_run_handlers:
            handler()
        return self.discover.start()

    def start(self, port=None):
        """Register the service handler and start listening.

        Args:
            port: port to bind on all interfaces; falls back to ``self.port``.
        """
        handler_class = self.decorate_handler(ServiceHandler)
        add_MilvusServiceServicer_to_server(
            handler_class(tracer=self.tracer,
                          router=self.router), self.server_impl)
        self.server_impl.add_insecure_port("[::]:{}".format(
            str(port or self.port)))
        self.server_impl.start()

    def run(self, port=None):
        """Run pre-run handlers, start the server and block until stopped.

        Exits the process with status 1 when ``on_pre_run`` returns a falsy
        value. A ``KeyboardInterrupt`` during the wait loop triggers ``stop``.

        Args:
            port: port to listen on; falls back to ``self.port`` when falsy.
        """
        logger.info('Milvus server start ......')
        port = port or self.port
        ok = self.on_pre_run()

        if not ok:
            logger.error('Terminate server due to error found in on_pre_run')
            sys.exit(1)

        self.start(port)
        logger.info(f'Server Version: {settings.SERVER_VERSIONS[-1]}')
        logger.info(f'Python SDK Version: {milvus.__version__}')
        logger.info('Listening on port {}'.format(port))

        try:
            # Keep the calling thread alive until stop() flips exit_flag.
            while not self.exit_flag:
                time.sleep(5)
        except KeyboardInterrupt:
            self.stop()

    def stop(self):
        """Stop the gRPC server immediately and close the tracer."""
        logger.info('Server is shuting down ......')
        self.exit_flag = True
        self.server_impl.stop(0)
        self.tracer.close()
        logger.info('Server is closed')

    def decorate_handler(self, handler):
        """Wrap every gRPC method of ``handler`` with the error-handling wrapper.

        The class is modified in place and returned. Only attributes defined
        directly on ``handler`` (its own ``__dict__``) are considered.
        """
        # Iterate over a snapshot: setattr() below writes to the same
        # namespace that is being walked.
        for key, attr in list(vars(handler).items()):
            if is_grpc_method(attr):
                setattr(handler, key, self.wrap_method_with_errorhandler(attr))
        return handler
161