Passed
Push — master ( 0d356b...652bd4 )
by Michael
11:41 queued 05:19
created

WsJsonRpcClient._handle_ws_message()   A

Complexity

Conditions 5

Size

Total Lines 18
Code Lines 14

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 14
dl 0
loc 18
rs 9.2333
c 0
b 0
f 0
cc 5
nop 2
1
import asyncio
2
import json
3
import logging
4
import typing
5
6
import aiohttp
7
from aiohttp import http_websocket, web_ws
8
9
from .base import BaseJsonRpcClient
10
from .. import errors, utils
11
12
13
__all__ = ('WsJsonRpcClient',)

logger = logging.getLogger(__name__)

# The client can drive either kind of aiohttp websocket: one it opened via
# ``ClientSession.ws_connect`` or a server-side ``WebSocketResponse`` handed in.
WSConnectType = typing.Union[
    aiohttp.ClientWebSocketResponse,
    web_ws.WebSocketResponse,
]
20
21
22
class WsJsonRpcClient(BaseJsonRpcClient):
23
    url: typing.Optional[str]
24
    ws_connect: typing.Optional[WSConnectType]
25
    timeout: typing.Optional[int]
26
    ws_connect_kwargs: dict
27
    _pending: typing.Dict[typing.Any, asyncio.Future]
28
    _message_worker: typing.Optional[asyncio.Future] = None
29
    _session_is_outer: bool
30
    _ws_connect_is_outer: bool
31
    _json_request_handler: typing.Optional[typing.Callable] = None
32
    _unprocessed_json_response_handler: typing.Optional[typing.Callable] = None
33
34
    def __init__(self,
                 url: typing.Optional[str] = None, *,
                 session: typing.Optional[aiohttp.ClientSession] = None,
                 ws_connect: typing.Optional[WSConnectType] = None,
                 timeout: typing.Optional[int] = 5,
                 json_request_handler: typing.Optional[typing.Callable] = None,
                 unprocessed_json_response_handler: typing.Optional[typing.Callable] = None,
                 **ws_connect_kwargs) -> None:
        """Configure the client; no network I/O happens until ``connect()``.

        :param url: websocket URL passed to ``session.ws_connect`` by ``connect()``.
        :param session: caller-owned ``aiohttp.ClientSession``; it is not closed by ``disconnect()``.
        :param ws_connect: caller-owned, already open websocket; it is not closed by ``disconnect()``.
        :param timeout: seconds ``send_json`` waits for a reply (via ``asyncio.wait_for``);
            ``None`` waits without limit.
        :param json_request_handler: awaited with ``ws_connect``, ``ws_msg`` and ``json_request``
            for incoming messages that carry a ``method``.
        :param unprocessed_json_response_handler: called (not awaited) with ``ws_connect``,
            ``ws_msg`` and ``json_response`` for replies that match no pending request.
        :param ws_connect_kwargs: extra keyword arguments forwarded to ``session.ws_connect``.
        """
        # Same condition as "session, or url without session, or ws_connect" - at least
        # one way of obtaining a websocket has to be supplied.
        assert session is not None or url is not None or ws_connect is not None

        self.url, self.timeout = url, timeout

        # Objects supplied by the caller are "outer": disconnect() leaves them open.
        self.session, self._session_is_outer = session, session is not None
        self.ws_connect, self._ws_connect_is_outer = ws_connect, ws_connect is not None
        self.ws_connect_kwargs = ws_connect_kwargs

        # request id -> future awaiting the reply (a batch shares one future).
        self._pending = {}
        self._json_request_handler = json_request_handler
        self._unprocessed_json_response_handler = unprocessed_json_response_handler
57
58
    async def connect(self) -> None:
        """Open the websocket if needed and start the background reader.

        A ``ClientSession`` is created only when neither a session nor a websocket
        was supplied. If opening the websocket fails, ``disconnect()`` is awaited to
        release what was created and the original error is re-raised.
        """
        if not self.ws_connect:
            if not self.session:
                self.session = aiohttp.ClientSession(json_serialize=self.json_serialize)

            try:
                self.ws_connect = await self.session.ws_connect(self.url, **self.ws_connect_kwargs)
            except Exception:
                await self.disconnect()
                raise

        self._message_worker = asyncio.ensure_future(self._handle_ws_messages())
70
71
    async def disconnect(self) -> None:
        """Close what this client opened itself and wait for the reader to finish.

        A websocket or session supplied by the caller ("outer") is left open. Objects the
        client created are closed, and their attributes are reset to ``None`` so that a
        later ``connect()`` opens fresh ones instead of reusing closed objects. The
        reader future is reset as well, so calling ``disconnect()`` twice is harmless.
        """
        if self.ws_connect and not self._ws_connect_is_outer:
            await self.ws_connect.close()

        if self.session and not self._session_is_outer:
            await self.session.close()

        try:
            if self._message_worker:
                # NOTE: with an outer websocket this waits until its owner closes it,
                # because the reader only stops when iteration over the websocket ends.
                await self._message_worker
        finally:
            # Reset even if the reader finished with an error, so reconnecting works.
            self._message_worker = None

            if not self._ws_connect_is_outer:
                self.ws_connect = None

            if not self._session_is_outer:
                self.session = None
80
81
    async def send_json(self,
                        data: typing.Any, *,
                        without_response: bool = False) -> typing.Tuple[typing.Any, typing.Optional[dict]]:
        """Serialize ``data``, send it and, if it carries ids, wait for the matching reply.

        :param data: a JSON-RPC message (dict) or a batch (list of them).
        :param without_response: when true, only send and return ``(None, None)``.
        :return: ``(reply, None)`` where ``reply`` is the decoded JSON the reader matched
            by id, or ``(None, None)`` when nothing is awaited (``without_response`` or
            no item carries a non-null ``id``).
        :raises asyncio.TimeoutError: if ``self.timeout`` is set and no reply arrives in time.
        """
        if without_response:
            await self.ws_connect.send_str(self.json_serialize(data))
            return None, None

        request_ids = self._get_ids_from_json(data)

        if not request_ids:
            # Nothing carries an id, so no reply will be matched; just send.
            await self.ws_connect.send_str(self.json_serialize(data))
            return None, None

        future = asyncio.Future()

        # A batch registers one shared future under every id it contains.
        for request_id in request_ids:
            self._pending[request_id] = future

        try:
            await self.ws_connect.send_str(self.json_serialize(data))

            if self.timeout is not None:
                result = await asyncio.wait_for(future, timeout=self.timeout)
            else:
                result = await future
        finally:
            # Deregister on every exit path (reply, timeout, send failure, cancellation,
            # batch reply missing some ids). Without this ``_pending`` grows forever and a
            # late reply would try to resolve a future ``wait_for`` already cancelled.
            # Only our own entries are removed, in case an id was reused meanwhile.
            for request_id in request_ids:
                if self._pending.get(request_id) is future:
                    del self._pending[request_id]

        return result, None
105
106
    def clear_pending(self) -> None:
        """Forget every outstanding request.

        Only the id -> future mapping is emptied; the futures themselves are
        neither resolved nor cancelled here.
        """
        pending = self._pending
        pending.clear()
108
109
    @staticmethod
    def _get_ids_from_json(data: typing.Any) -> typing.Optional[list]:
        """Collect the non-null ``id`` values of a message or a batch.

        :param data: a single JSON-RPC message (dict) or a batch (list of them).
        :return: ``[id]`` for a dict whose ``id`` is not ``None``; for a list, the
            non-``None`` ids of its dict items, in order; ``[]`` for anything else,
            including falsy ``data``. ``None`` is never returned.
        """
        if not data:
            return []

        if isinstance(data, list):
            ids = []
            for item in data:
                if isinstance(item, dict) and item.get('id') is not None:
                    ids.append(item['id'])
            return ids

        if isinstance(data, dict) and data.get('id') is not None:
            return [data['id']]

        return []
125
126
    async def _handle_ws_messages(self) -> None:
        """Read the websocket until iteration ends and dispatch every text frame.

        Each text frame is handled in its own task, so a slow handler does not block
        reading further frames. If this reader is cancelled, every pending request is
        failed with ``errors.ServerError`` before the cancellation is re-raised.
        """
        # Strong references to in-flight handler tasks: the event loop keeps only weak
        # references to tasks, so an unreferenced task may vanish mid-execution.
        handler_tasks = set()

        def _on_handler_done(task: asyncio.Future) -> None:
            handler_tasks.discard(task)

            if task.cancelled():
                return

            # ensure_future() never raises the handler's own errors, so report them here
            # instead of leaving them as "Task exception was never retrieved".
            exc = task.exception()
            if exc is not None:
                logger.warning('Can not process WS message.', exc_info=exc)

        try:
            async for ws_msg in self.ws_connect:
                if ws_msg.type != http_websocket.WSMsgType.TEXT:
                    continue

                try:
                    task = asyncio.ensure_future(self._handle_single_ws_message(ws_msg))
                except Exception:
                    logger.warning('Can not process WS message.', exc_info=True)
                    continue

                handler_tasks.add(task)
                task.add_done_callback(_on_handler_done)
        except asyncio.CancelledError as e:
            # Cancellation is delivered while awaiting the next frame (the ``async for``),
            # so it has to be caught around the loop, not around ensure_future().
            error = errors.ServerError(utils.get_exc_message(e)).with_traceback()
            self._notify_all_about_error(error)
            raise
139
140
    async def _handle_single_ws_message(self, ws_msg: aiohttp.WSMessage) -> None:
        """Decode one websocket frame as JSON and route it by its top-level type.

        Non-text frames and falsy payloads (e.g. ``{}``, ``[]``, ``null``) are ignored;
        a decoding failure is logged and ignored. Objects go to
        ``_handle_single_json_response``, arrays to ``_handle_json_responses``, and
        any other value is logged as unprocessable.
        """
        if ws_msg.type != aiohttp.WSMsgType.text:
            return

        try:
            payload = json.loads(ws_msg.data)
        except Exception:
            logger.warning('Can not parse json.', exc_info=True)
            return

        if not payload:
            return

        if isinstance(payload, dict):
            route = self._handle_single_json_response
        elif isinstance(payload, list):
            route = self._handle_json_responses
        else:
            logger.warning('Couldn\'t process the response.', extra={
                'json_response': payload,
            })
            return

        await route(payload, ws_msg=ws_msg)
164
165
    async def _handle_single_json_response(self, json_response: dict, *, ws_msg: aiohttp.WSMessage) -> None:
        """Route one decoded JSON object received from the peer.

        - an object with a ``method`` key is an incoming request: it is awaited through
          ``json_request_handler`` when one is configured, otherwise dropped;
        - an object whose ``id`` matches a pending request resolves that request;
        - anything else is passed to ``unprocessed_json_response_handler`` when set
          (the handler is called, not awaited).
        """
        if 'method' in json_response:
            request_handler = self._json_request_handler
            if request_handler:
                await request_handler(
                    ws_connect=self.ws_connect,
                    ws_msg=ws_msg,
                    json_request=json_response,
                )
            return

        if 'id' in json_response and json_response['id'] in self._pending:
            self._notify_about_result(json_response['id'], json_response)
            return

        fallback = self._unprocessed_json_response_handler
        if fallback:
            fallback(
                ws_connect=self.ws_connect,
                ws_msg=ws_msg,
                json_response=json_response,
            )
181
182
    async def _handle_json_responses(self, json_responses: list, *, ws_msg: aiohttp.WSMessage) -> None:
        """Route a batch (JSON array) received from the peer.

        A batch whose first item is a dict with a ``method`` key is treated as a batch
        of incoming requests and awaited through ``json_request_handler`` (if set),
        with the same keyword arguments the single-message path uses. Otherwise it is a
        batch of replies: when any item's ``id`` matches a pending request, the waiting
        future(s) are resolved with the whole batch; when none does, the batch is passed
        to ``unprocessed_json_response_handler`` if one is configured.
        """
        if not json_responses:
            return

        first = json_responses[0]

        if isinstance(first, dict) and 'method' in first:
            if self._json_request_handler:
                # Previously ``json_request`` was omitted here although the single
                # message path always passes it.
                await self._json_request_handler(
                    ws_connect=self.ws_connect,
                    ws_msg=ws_msg,
                    json_request=json_responses,
                )
            return

        response_ids = self._get_ids_from_json(json_responses)

        # Mirror the single-message path: replies matching no pending request are
        # "unprocessed" rather than silently dropped.
        if any(response_id in self._pending for response_id in response_ids):
            self._notify_about_results(response_ids, json_responses)
        elif self._unprocessed_json_response_handler:
            self._unprocessed_json_response_handler(
                ws_connect=self.ws_connect,
                ws_msg=ws_msg,
                json_response=json_responses,
            )
197
198
    def _notify_all_about_error(self, error: errors.JsonRpcError) -> None:
        """Fail every still-waiting request with ``error`` and empty ``_pending``.

        A batch registers one future under several ids, and futures cancelled by a
        timeout may still be registered. ``set_exception`` raises ``InvalidStateError``
        on a future that is already done, so done futures are skipped. This also makes
        a shared batch future receive the exception exactly once.
        """
        futures = list(self._pending.values())
        self.clear_pending()

        for future in futures:
            if not future.done():
                future.set_exception(error)
203
204
    def _notify_about_result(self, response_id: typing.Any, json_response: dict) -> None:
        """Resolve the request waiting on ``response_id`` with ``json_response``.

        The id is removed from ``_pending`` either way. A future that is already done
        (for example cancelled by ``asyncio.wait_for`` on timeout) is left alone,
        because ``set_result`` would raise ``InvalidStateError`` on it.
        """
        future = self._pending.pop(response_id, None)

        if future is not None and not future.done():
            future.set_result(json_response)
209
210
    def _notify_about_results(self, response_ids: list, json_response: list) -> None:
        """Resolve every request waiting on any of ``response_ids`` with the whole batch.

        All listed ids are removed from ``_pending``. Because ``set_result`` marks a
        future done immediately, a future shared by several ids of one batch is
        resolved only once. Futures that were already done (e.g. cancelled on timeout)
        are skipped, since ``set_result`` would raise ``InvalidStateError`` on them.
        """
        for response_id in response_ids:
            future = self._pending.pop(response_id, None)

            if future is not None and not future.done():
                future.set_result(json_response)
219