|
1
|
|
|
import asyncio |
|
2
|
|
|
import json |
|
3
|
|
|
import typing |
|
4
|
|
|
import weakref |
|
5
|
|
|
|
|
6
|
|
|
from aiohttp import http_websocket, web, web_ws |
|
7
|
|
|
|
|
8
|
|
|
from .base import BaseJsonRpcServer |
|
9
|
|
|
from .. import errors, protocol, utils |
|
10
|
|
|
|
|
11
|
|
|
|
|
12
|
|
|
# Public API of this module: only the WebSocket JSON-RPC server class is exported.
__all__ = (
    'WsJsonRpcServer',
)
|
15
|
|
|
|
|
16
|
|
|
|
|
17
|
|
|
class WsJsonRpcServer(BaseJsonRpcServer):
    """JSON-RPC server that receives requests over aiohttp WebSocket connections.

    Each incoming TEXT frame is decoded as JSON and processed in its own
    background task, so a slow call does not stop the server from reading
    further frames on the same connection.
    """

    # Server-side websocket connections that are currently being served.
    # Weak references, so a connection object can still be garbage-collected
    # even if it is never explicitly discarded.
    rcp_websockets: weakref.WeakSet
    # Optional callable passed to ``__init__``; not read anywhere in this class
    # (presumably consumed by the base class or subclasses -- verify).
    _json_response_handler: typing.Optional[typing.Callable] = None
    # Strong references to in-flight message tasks. The event loop keeps only
    # weak references to tasks, so an unreferenced task could vanish mid-run.
    _background_tasks: typing.Set[asyncio.Task]

    def __init__(self,
                 *args,
                 json_response_handler: typing.Optional[typing.Callable] = None,
                 **kwargs) -> None:
        """Create the server.

        :param json_response_handler: optional callable, stored as
            ``_json_response_handler``.

        All other positional and keyword arguments are forwarded unchanged to
        ``BaseJsonRpcServer.__init__``.
        """
        super().__init__(*args, **kwargs)

        self.rcp_websockets = weakref.WeakSet()
        self._json_response_handler = json_response_handler
        self._background_tasks = set()

    async def handle_http_request(self, http_request: web.Request) -> web.StreamResponse:
        """Handle an HTTP request, accepting only WebSocket upgrade requests.

        :param http_request: the incoming aiohttp request.
        :returns: the prepared websocket response, once the connection has closed.
        :raises web.HTTPMethodNotAllowed: if the request is not a ``GET`` with an
            ``Upgrade: websocket`` header.
        """
        if http_request.method != 'GET' or http_request.headers.get('upgrade', '').lower() != 'websocket':
            # NOTE(review): a plain GET without the upgrade header also gets a 405
            # that lists GET as allowed; a 400/426 might describe the problem better.
            raise web.HTTPMethodNotAllowed(method=http_request.method, allowed_methods=('GET',))

        return await self._handle_ws_request(http_request)

    async def on_shutdown(self, app: web.Application) -> None:
        """Close every tracked websocket with code ``GOING_AWAY``, then forget them all.

        Meant to be registered as an aiohttp ``on_shutdown`` callback, see
        https://docs.aiohttp.org/en/stable/web_advanced.html#graceful-shutdown
        """
        # Iterate over a snapshot: ``ws.close()`` yields to the event loop, and
        # while it does, connection handlers may add or discard entries. Mutating
        # the underlying set mid-iteration would raise RuntimeError.
        for ws in tuple(self.rcp_websockets):
            await ws.close(code=http_websocket.WSCloseCode.GOING_AWAY, message='Server shutdown')

        self.rcp_websockets.clear()

    async def _handle_ws_request(self, http_request: web.Request) -> web_ws.WebSocketResponse:
        """Upgrade the request to a websocket and serve JSON-RPC frames until it closes.

        Every TEXT frame is dispatched to ``_handle_ws_message`` in a separate
        task; all other frame types are ignored. The connection is tracked in
        ``rcp_websockets`` while it is being served.
        """
        # Function-level import -- presumably to avoid a circular import with the
        # package root; confirm before moving it to module level.
        from aiohttp_rpc import WsJsonRpcClient

        ws_connect = web_ws.WebSocketResponse()
        await ws_connect.prepare(http_request)

        self.rcp_websockets.add(ws_connect)

        try:
            # Client bound to this same socket, exposed to method handlers via the
            # context (presumably so they can issue calls back to the peer).
            ws_rpc_client = WsJsonRpcClient(ws_connect=ws_connect)

            ws_msg: http_websocket.WSMessage

            async for ws_msg in ws_connect:
                if ws_msg.type != http_websocket.WSMsgType.TEXT:
                    continue

                # A fresh context dict per message, so one call cannot leak state
                # into another through it.
                context = {
                    'http_request': http_request,
                    'ws_connect': ws_connect,
                    'ws_rpc_client': ws_rpc_client,
                }
                task = asyncio.create_task(
                    self._handle_ws_message(ws_msg=ws_msg, ws_connect=ws_connect, context=context),
                )

                # To avoid a task disappearing mid execution; the task removes
                # itself from the set when it finishes.
                self._background_tasks.add(task)
                task.add_done_callback(self._background_tasks.discard)
        finally:
            # The connection is finished (or failed): stop tracking it so
            # ``on_shutdown`` does not try to close it again.
            self.rcp_websockets.discard(ws_connect)

        return ws_connect

    async def _handle_ws_message(self,
                                 ws_msg: web_ws.WSMessage, *,
                                 ws_connect: web_ws.WebSocketResponse,
                                 context: dict) -> None:
        """Decode one TEXT frame, process it, and send the JSON reply on the socket.

        Invalid JSON produces a JSON-RPC ``ParseError`` response. Nothing is sent
        when ``_process_input_data`` returns ``None`` (presumably for
        notifications -- confirm in the base class).

        :param ws_msg: the received websocket message; its ``data`` is parsed as JSON.
        :param ws_connect: the socket the reply is written to.
        :param context: per-message context forwarded to ``_process_input_data``.
        :raises errors.ServerError: if the socket is already closed when a reply
            is ready to be sent.
        """
        json_response: typing.Optional[typing.Union[typing.Mapping, typing.Sequence[typing.Mapping]]]

        try:
            input_data = json.loads(ws_msg.data)
        except json.JSONDecodeError as e:
            response = protocol.JsonRpcResponse(error=errors.ParseError(utils.get_exc_message(e)))
            json_response = response.dump()
        else:
            json_response = await self._process_input_data(input_data, context=context)

        if json_response is None:
            return

        if ws_connect.closed:
            # NOTE(review): this runs in a background task that nobody awaits, so
            # the exception is only surfaced as an unretrieved task exception --
            # consider logging instead.
            raise errors.ServerError('WS is closed.')

        await ws_connect.send_str(self.json_serialize(json_response))
|
101
|
|
|
|