aiohttp_rpc.server.websocket   A
last analyzed

Complexity

Total Complexity 14

Size/Duplication

Total Lines 101
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 14
eloc 67
dl 0
loc 101
rs 10
c 0
b 0
f 0

5 Methods

Rating   Name   Duplication   Size   Complexity  
A WsJsonRpcServer.on_shutdown() 0 7 2
A WsJsonRpcServer.__init__() 0 9 1
A WsJsonRpcServer._handle_ws_message() 0 21 5
A WsJsonRpcServer._handle_ws_request() 0 33 3
A WsJsonRpcServer.handle_http_request() 0 5 3
1
import asyncio
2
import json
3
import typing
4
import weakref
5
6
from aiohttp import http_websocket, web, web_ws
7
8
from .base import BaseJsonRpcServer
9
from .. import errors, protocol, utils
10
11
12
__all__ = (
13
    'WsJsonRpcServer',
14
)
15
16
17
class WsJsonRpcServer(BaseJsonRpcServer):
    """JSON-RPC server that exchanges messages over aiohttp WebSockets.

    Every TEXT frame received on a connection is decoded as JSON, passed to
    ``self._process_input_data`` and, when that yields a response, the
    response is serialized with ``self.json_serialize`` and sent back on the
    same socket.  Each frame is processed in its own asyncio task so that the
    receive loop is never blocked by a slow call.
    """

    # Currently registered WebSocket responses, held weakly.  The attribute
    # name (``rcp_``, sic) is public and therefore kept unchanged.
    rcp_websockets: weakref.WeakSet
    # Stored by ``__init__``; not read anywhere in this class.
    _json_response_handler: typing.Optional[typing.Callable] = None
    # Strong references to in-flight message tasks (see ``_handle_ws_request``).
    _background_tasks: typing.Set

    def __init__(self,
                 *args,
                 json_response_handler: typing.Optional[typing.Callable] = None,
                 **kwargs) -> None:
        """Initialise the server.

        Args:
            *args: Forwarded unchanged to the base class constructor.
            json_response_handler: Optional callable stored on the instance
                as ``_json_response_handler``.
            **kwargs: Forwarded unchanged to the base class constructor.
        """
        super().__init__(*args, **kwargs)

        self.rcp_websockets = weakref.WeakSet()
        self._json_response_handler = json_response_handler
        self._background_tasks = set()

    async def handle_http_request(self, http_request: web.Request) -> web.StreamResponse:
        """Upgrade a GET request to a WebSocket and serve JSON-RPC on it.

        Args:
            http_request: The incoming aiohttp request.

        Returns:
            The WebSocket response once the connection's receive loop ends.

        Raises:
            web.HTTPMethodNotAllowed: If the request is not a GET carrying an
                ``Upgrade: websocket`` header (this includes a plain GET
                without the upgrade header).
        """
        if http_request.method != 'GET' or http_request.headers.get('upgrade', '').lower() != 'websocket':
            raise web.HTTPMethodNotAllowed(method=http_request.method, allowed_methods=('GET',))

        return await self._handle_ws_request(http_request)

    async def on_shutdown(self, app: web.Application) -> None:
        """Close every registered WebSocket with a GOING_AWAY close code.

        Args:
            app: The aiohttp application being shut down (not used here).
        """
        # https://docs.aiohttp.org/en/stable/web_advanced.html#graceful-shutdown

        # Iterate over a snapshot: ``ws.close()`` is awaited, and while it is
        # pending other coroutines may add or discard sockets.  Mutating the
        # WeakSet under a live iterator raises
        # "RuntimeError: Set changed size during iteration".
        for ws in set(self.rcp_websockets):
            await ws.close(code=http_websocket.WSCloseCode.GOING_AWAY, message='Server shutdown')

        self.rcp_websockets.clear()

    async def _handle_ws_request(self, http_request: web.Request) -> web_ws.WebSocketResponse:
        """Run the receive loop for one WebSocket connection.

        Prepares the WebSocket response, registers it in ``rcp_websockets``
        and spawns one task per received TEXT frame.  Frames of any other
        type are ignored.

        Args:
            http_request: The upgrade request for this connection.

        Returns:
            The WebSocket response object after the receive loop has ended.
        """
        # Imported locally rather than at module level — presumably to avoid
        # a circular import with the package root; confirm before moving it.
        from aiohttp_rpc import WsJsonRpcClient

        ws_connect = web_ws.WebSocketResponse()
        await ws_connect.prepare(http_request)

        self.rcp_websockets.add(ws_connect)

        try:
            ws_rpc_client = WsJsonRpcClient(ws_connect=ws_connect)

            ws_msg: http_websocket.WSMessage

            async for ws_msg in ws_connect:
                if ws_msg.type != http_websocket.WSMsgType.TEXT:
                    continue

                coro = self._handle_ws_message(
                    ws_msg=ws_msg,
                    ws_connect=ws_connect,
                    context={
                        'http_request': http_request,
                        'ws_connect': ws_connect,
                        'ws_rpc_client': ws_rpc_client,
                    },
                )

                task = asyncio.create_task(coro)

                # To avoid a task disappearing mid execution:
                self._background_tasks.add(task)
                task.add_done_callback(self._background_tasks.discard)
        finally:
            # The receive loop is over (normally or through an exception), so
            # stop tracking this socket instead of waiting for it to be
            # garbage-collected out of the WeakSet.
            self.rcp_websockets.discard(ws_connect)

        return ws_connect

    async def _handle_ws_message(self,
                                 ws_msg: web_ws.WSMessage, *,
                                 ws_connect: web_ws.WebSocketResponse,
                                 context: dict) -> None:
        """Process one TEXT frame and send the response, if there is one.

        Args:
            ws_msg: The received message; ``ws_msg.data`` is parsed as JSON.
            ws_connect: The socket the response is written to.
            context: Per-message context passed to ``_process_input_data``.

        Raises:
            errors.ServerError: If there is a response to send but the socket
                has already been closed.
        """
        json_response: typing.Optional[typing.Union[typing.Mapping, typing.Sequence[typing.Mapping]]]

        try:
            input_data = json.loads(ws_msg.data)
        except json.JSONDecodeError as e:
            # Malformed JSON: answer with a ParseError response instead of
            # dispatching anything.
            response = protocol.JsonRpcResponse(error=errors.ParseError(utils.get_exc_message(e)))
            json_response = response.dump()
        else:
            json_response = await self._process_input_data(input_data, context=context)

        # ``None`` means there is nothing to send back for this message.
        if json_response is None:
            return

        if ws_connect.closed:
            raise errors.ServerError('WS is closed.')

        await ws_connect.send_str(self.json_serialize(json_response))
101