Test Failed
Push — master ( c58336...8dd0a6 )
by Michael
03:57
created

aiohttp_rpc.server.BaseJsonRpcServer._process_single_request()   A

Complexity

Conditions 2

Size

Total Lines 21
Code Lines 16

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 16
nop 2
dl 0
loc 21
rs 9.6
c 0
b 0
f 0
1
import abc
2
import asyncio
3
import json
4
import typing
5
6
import aiohttp
7
from aiohttp import web, web_ws
8
9
from . import constants, errors, middlewares as rpc_middleware, protocol, utils
10
11
12
# Public API of this module, i.e. what ``from <module> import *`` exports.
__all__ = ('BaseJsonRpcServer', 'JsonRpcServer', 'WsJsonRpcServer', 'rpc_server')
18
19
20
class BaseJsonRpcServer(abc.ABC):
    """Transport-independent core of a JSON-RPC server.

    Holds the registry of callable RPC methods and the middleware chain, and
    turns an already-decoded JSON payload (a single request ``dict`` or a
    batch ``list``) into response data ready to be serialized.  Transport
    subclasses are responsible for reading input and writing output.
    """

    # Registered RPC methods, keyed by the name clients use to call them.
    methods: typing.Dict[str, protocol.JsonRpcMethod]
    # Middleware classes, or ``(class, kwargs)`` pairs; the first entry is the
    # outermost layer of the chain (see ``load_middlewares``).
    middlewares: typing.Tuple[typing.Type[rpc_middleware.BaseJsonRpcMiddleware], ...]
    # Callable the transports use to turn response data into a JSON string.
    json_serialize: typing.Callable
    # Entry point of the composed middleware chain; its innermost link is
    # ``_process_single_request``.
    _middleware_chain: typing.Callable

    def __init__(self, *,
                 json_serialize: typing.Callable = utils.json_serialize,
                 middlewares: typing.Iterable = (),
                 methods: typing.Optional[typing.Dict[str, protocol.JsonRpcMethod]] = None) -> None:
        """Create a server.

        Args:
            json_serialize: Serializer used by transports for responses.
            middlewares: Middleware classes or ``(class, kwargs)`` pairs,
                outermost first.
            methods: Initial method registry.  When omitted, a registry
                exposing only ``get_methods`` is created.  A dict passed in is
                stored as-is (not copied), so later ``add_method`` calls
                mutate the caller's dict.
        """
        if methods is None:
            methods = {'get_methods': protocol.JsonRpcMethod('', self.get_methods)}

        self.methods = methods
        # Set before the middlewares are built: each middleware constructor
        # receives ``server=self`` and may already want the serializer.
        self.json_serialize = json_serialize

        self.middlewares = tuple(middlewares)
        self.load_middlewares()

    def load_middlewares(self) -> None:
        """(Re)build ``_middleware_chain`` from ``self.middlewares``.

        Each entry is either a middleware class or a ``(class, kwargs)``
        pair.  Entries are wrapped from last to first, so the first entry ends
        up outermost and ``_process_single_request`` innermost.
        """
        self._middleware_chain = self._process_single_request

        for middleware in reversed(self.middlewares):
            if isinstance(middleware, (list, tuple)):
                middleware_class, kwargs = middleware
            else:
                middleware_class, kwargs = middleware, {}

            self._middleware_chain = middleware_class(
                server=self,
                get_response=self._middleware_chain,
                **kwargs,
            )

    def add_method(self,
                   method: typing.Union[protocol.JsonRpcMethod, tuple, list, typing.Callable], *,
                   replace: bool = False) -> protocol.JsonRpcMethod:
        """Register one RPC method and return the ``JsonRpcMethod`` stored.

        ``method`` may be a ready ``JsonRpcMethod``, a plain callable (wrapped
        as ``JsonRpcMethod('', method)``), or a tuple/list whose items are
        passed positionally to ``JsonRpcMethod``.

        Raises:
            errors.InvalidParams: a method with the same name is already
                registered and ``replace`` is false.
        """
        if not isinstance(method, protocol.JsonRpcMethod):
            if callable(method):
                method = protocol.JsonRpcMethod('', method)
            else:
                method = protocol.JsonRpcMethod(*method)

        if not replace and method.name in self.methods:
            raise errors.InvalidParams(f'Method {method.name} has already been added.')

        self.methods[method.name] = method

        return method

    def add_methods(self,
                    methods: typing.Iterable[typing.Union[protocol.JsonRpcMethod, tuple, list, typing.Callable]], *,
                    replace: bool = False) -> typing.List[protocol.JsonRpcMethod]:
        """Register several methods via ``add_method``; return them in input order."""
        return [
            self.add_method(method, replace=replace)
            for method in methods
        ]

    async def call(self,
                   method: str, *,
                   args: typing.Optional[list] = None,
                   kwargs: typing.Optional[dict] = None,
                   extra_args: typing.Optional[dict] = None) -> typing.Any:
        """Invoke the registered method ``method`` and return its result.

        ``args``/``kwargs`` default to empty containers; ``extra_args`` is
        forwarded unchanged (``None`` included).

        Raises:
            errors.MethodNotFound: no method named ``method`` is registered.
            Any exception raised by the method itself propagates unchanged.
        """
        if args is None:
            args = []

        if kwargs is None:
            kwargs = {}

        if method not in self.methods:
            raise errors.MethodNotFound

        return await self.methods[method](args=args, kwargs=kwargs, extra_args=extra_args)

    def get_methods(self) -> dict:
        """Describe every registered method: its docstring and supported args/kwargs."""
        return {
            name: {
                'doc': method.func.__doc__,
                'args': method.supported_args,
                'kwargs': method.supported_kwargs,
            }
            for name, method in self.methods.items()
        }

    async def _process_input_data(self,
                                  data: typing.Union[dict, list], *,
                                  http_request: typing.Optional[web.Request] = None) -> typing.Any:
        """Process a decoded payload and return serializable response data.

        A ``list`` is a batch: its entries are processed concurrently and a
        list of response dicts is returned in the same order.  A ``dict`` is a
        single request and yields a single response dict.  Any other type
        yields a Parse error response dict.
        """
        if isinstance(data, list):
            json_responses = await asyncio.gather(*(
                self._process_single_json_request(raw_rpc_request, http_request=http_request)
                for raw_rpc_request in data
            ), return_exceptions=True)

            # gather(return_exceptions=True) puts raised exceptions in place of
            # results; turn each into an error response *dict* so the batch is
            # uniformly a list of dicts like the successful entries.
            for i, json_response in enumerate(json_responses):
                if isinstance(json_response, errors.JsonRpcError):
                    error = json_response
                elif isinstance(json_response, Exception):
                    error = errors.JsonRpcError(utils.get_exc_message(json_response))
                else:
                    continue

                json_responses[i] = protocol.JsonRpcResponse(error=error).to_dict()

            return json_responses

        if isinstance(data, dict):
            return await self._process_single_json_request(data, http_request=http_request)

        response = protocol.JsonRpcResponse(error=errors.ParseError('Data must be a dict or an list.'))
        return response.to_dict()

    async def _process_single_json_request(self,
                                           json_request: dict, *,
                                           http_request: typing.Optional[web.Request] = None) -> dict:
        """Parse one request dict, run it through the middleware chain, return the response dict.

        If ``JsonRpcRequest.from_dict`` rejects the request with a
        ``JsonRpcError``, an error response carrying the request's ``id`` (if
        any) is returned instead.

        Raises:
            errors.ParseError: ``json_request`` is not a dict (the batch path
                converts this into an error entry).
        """
        if not isinstance(json_request, dict):
            raise errors.ParseError('Data must be a dict or an list.')

        # Read the id up front so even an unparseable request is answered
        # with the id the client sent.
        msg_id = json_request.get('id')

        try:
            request = protocol.JsonRpcRequest.from_dict(json_request, context={'http_request': http_request})
        except errors.JsonRpcError as e:
            response = protocol.JsonRpcResponse(msg_id=msg_id, error=e)
            return response.to_dict()

        response = await self._middleware_chain(request)
        return response.to_dict()

    async def _process_single_request(self, request: protocol.JsonRpcRequest) -> protocol.JsonRpcResponse:
        """Innermost link of the middleware chain: call the method, wrap the outcome.

        Only ``errors.JsonRpcError`` is converted into an error response;
        any other exception propagates to the surrounding middlewares/caller.
        """
        # constants.NOTHING marks whichever of result/error was not produced;
        # presumably JsonRpcResponse uses it to tell "absent" from a None
        # result - confirm in protocol.
        result, error = constants.NOTHING, constants.NOTHING

        try:
            result = await self.call(
                request.method,
                args=request.args,
                kwargs=request.kwargs,
                extra_args=request.extra_args,
            )
        except errors.JsonRpcError as e:
            error = e

        response = protocol.JsonRpcResponse(
            msg_id=request.msg_id,
            jsonrpc=request.jsonrpc,
            result=result,
            error=error,
        )

        return response
164
165
166
class JsonRpcServer(BaseJsonRpcServer):
    """JSON-RPC over plain HTTP: one POSTed JSON body in, one JSON response out."""

    async def handle_http_request(self, http_request: web.Request) -> web.Response:
        """aiohttp handler: decode the POSTed JSON body and answer with JSON.

        Non-POST requests get 405 Method Not Allowed (raised, as aiohttp
        expects; returning HTTP exceptions is deprecated).  A body that cannot
        be decoded as JSON gets a JSON-RPC Parse error response.
        """
        if http_request.method != 'POST':
            raise web.HTTPMethodNotAllowed(method=http_request.method, allowed_methods=('POST',))

        try:
            input_data = await http_request.json()
        except ValueError as e:
            # Covers json.JSONDecodeError and also UnicodeDecodeError from a
            # body that is not valid text in its declared charset; both derive
            # from ValueError.
            response = protocol.JsonRpcResponse(error=errors.ParseError(utils.get_exc_message(e)))
            return web.json_response(response.to_dict(), dumps=self.json_serialize)

        output_data = await self._process_input_data(input_data, http_request=http_request)

        return web.json_response(output_data, dumps=self.json_serialize)
180
181
182
class WsJsonRpcServer(BaseJsonRpcServer):
    """JSON-RPC over WebSocket: each incoming text frame is one JSON-RPC payload."""

    async def handle_http_request(self, http_request: web.Request) -> web.StreamResponse:
        """aiohttp handler: upgrade a GET request to a WebSocket and serve it.

        Anything other than a GET carrying ``Upgrade: websocket`` is rejected
        with 405 Method Not Allowed, advertising GET as the allowed method.
        """
        if http_request.method == 'GET' and http_request.headers.get('upgrade', '').lower() == 'websocket':
            return await self.handle_websocket_request(http_request)

        raise web.HTTPMethodNotAllowed(method=http_request.method, allowed_methods=('GET',))

    async def handle_websocket_request(self, http_request: web.Request) -> web_ws.WebSocketResponse:
        """Open the WebSocket and answer text frames until the socket closes.

        Non-text frames are ignored.  The socket is stored on
        ``http_request.ws`` so per-message handling can reply on it.
        """
        # Not read in this class; presumably per-connection state for
        # server-initiated calls elsewhere - confirm before removing.
        http_request.msg_id = 0
        http_request.pending = {}

        ws = web_ws.WebSocketResponse()
        await ws.prepare(http_request)
        http_request.ws = ws

        while not ws.closed:
            ws_msg = await ws.receive()

            if ws_msg.type != aiohttp.WSMsgType.TEXT:
                continue

            await self._handle_ws_msg(http_request, ws_msg)

        return ws

    async def _handle_ws_msg(self, http_request: web.Request, ws_msg: web_ws.WSMessage) -> None:
        """Decode one text frame, process it, and send the response on the socket.

        A frame that is not valid JSON is answered with a JSON-RPC Parse error
        rather than raising out of (and ending) the receive loop.  If the
        transport is already closing, the socket is closed and nothing is sent.
        """
        try:
            input_data = json.loads(ws_msg.data)
        except ValueError as e:
            response = protocol.JsonRpcResponse(error=errors.ParseError(utils.get_exc_message(e)))
            output_data = response.to_dict()
        else:
            output_data = await self._process_input_data(input_data, http_request=http_request)

        ws = http_request.ws

        if ws._writer.transport.is_closing():
            # The peer went away while the request was processed: nobody is
            # left to reply to, so release the connection and stop.
            # ``clients`` is not defined by this class; deregister only when
            # a subclass provides such a registry.
            clients = getattr(self, 'clients', None)
            if clients is not None and http_request in clients:
                clients.remove(http_request)
            await ws.close()
            return

        await ws.send_str(self.json_serialize(output_data))
216
217
218
# Ready-made module-level HTTP server instance, built with the middleware set
# from ``rpc_middleware.DEFAULT_MIDDLEWARES``.
rpc_server = JsonRpcServer(middlewares=rpc_middleware.DEFAULT_MIDDLEWARES)
221