Passed
Push — master ( b981ec...396176 )
by Michael
04:53
created

WsJsonRpcClient._handle_ws_messages()   B

Complexity

Conditions 6

Size

Total Lines 21
Code Lines 16

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 16
dl 0
loc 21
rs 8.6666
c 0
b 0
f 0
cc 6
nop 1
1
import asyncio
2
import json
3
import logging
4
import typing
5
6
from aiohttp import ClientSession, http_websocket, web_ws
7
8
from .base import BaseJsonRpcClient
9
from .. import errors, typedefs, utils
10
11
12
# Public API of this module.
__all__ = ('WsJsonRpcClient',)


# Module-level logger shared by the client's message handlers.
logger = logging.getLogger(__name__)
17
18
19
class WsJsonRpcClient(BaseJsonRpcClient):
20
    url: typing.Optional[str]
21
    ws_connect: typing.Optional[typedefs.WSConnectType]
22
    ws_connect_kwargs: dict
23
    timeout: typing.Optional[int]
24
    _pending: typing.Dict[typing.Any, asyncio.Future]
25
    _message_worker: typing.Optional[asyncio.Future] = None
26
    _session_is_outer: bool
27
    _ws_connect_is_outer: bool
28
    _json_request_handler: typing.Optional[typing.Callable] = None
29
    _unprocessed_json_response_handler: typing.Optional[typing.Callable] = None
30
    _background_tasks: typing.Set
31
32
    def __init__(self,
                 url: typing.Optional[str] = None, *,
                 session: typing.Optional[ClientSession] = None,
                 ws_connect: typing.Optional[typedefs.WSConnectType] = None,
                 timeout: typing.Optional[int] = 5,
                 json_request_handler: typing.Optional[typing.Callable] = None,
                 unprocessed_json_response_handler: typing.Optional[typing.Callable] = None,
                 **ws_connect_kwargs) -> None:
        """Store the client configuration; no connection is opened here.

        :param url: WebSocket URL that ``connect`` opens when no ``ws_connect`` is supplied.
        :param session: an outer ``ClientSession``; ``disconnect`` never closes it.
        :param ws_connect: an outer WS connection; ``disconnect`` never closes it.
        :param timeout: seconds ``send_json`` waits for a response; ``None`` disables the limit.
        :param json_request_handler: awaited for incoming messages that carry a ``method`` key.
        :param unprocessed_json_response_handler: called for incoming messages that match
            no pending request.
        :param ws_connect_kwargs: forwarded to ``session.ws_connect`` by ``connect``.
        """
        # Some source of a connection must be available.
        assert session is not None or url is not None or ws_connect is not None

        self.url = url
        self.timeout = timeout

        self.session = session
        # Objects handed in from outside are owned by the caller and stay open.
        self._session_is_outer = session is not None

        self.ws_connect = ws_connect
        self.ws_connect_kwargs = ws_connect_kwargs
        self._ws_connect_is_outer = ws_connect is not None

        # Request id -> future awaited by ``send_json``.
        self._pending = {}
        self._json_request_handler = json_request_handler
        self._unprocessed_json_response_handler = unprocessed_json_response_handler
        # Strong references to in-flight message-handling tasks.
        self._background_tasks = set()
56
57
    async def connect(self) -> None:
        """Open the WS connection (unless one was supplied) and start the reader task.

        When neither a session nor a connection was supplied, a private ``ClientSession``
        is created. If opening the connection fails, ``disconnect`` is awaited to release
        owned resources and the error is re-raised.
        """
        if self.ws_connect is None:
            if self.session is None:
                self.session = ClientSession(json_serialize=self.json_serialize)

            assert self.url is not None and self.session is not None

            try:
                self.ws_connect = await self.session.ws_connect(self.url, **self.ws_connect_kwargs)
            except Exception:
                await self.disconnect()
                raise

        self._message_worker = asyncio.create_task(self._handle_ws_messages())
71
72
    async def disconnect(self) -> None:
        """Close owned resources and wait for the reader task to finish.

        An owned WS connection and an owned session are closed; outer ones are left open.
        Closing an owned connection ends the reader loop on its own. An outer connection
        stays open, so the reader task is cancelled instead of being awaited forever.
        """
        if self.ws_connect is not None and not self._ws_connect_is_outer:
            await self.ws_connect.close()

        if self.session is not None and not self._session_is_outer:
            await self.session.close()

        worker = self._message_worker

        if worker is None:
            return

        if self._ws_connect_is_outer and not worker.done():
            # The outer connection is not closed here, so ``async for`` over it would never
            # end; stop reading instead of hanging until the owner closes the socket.
            worker.cancel()
            # ``asyncio.wait`` does not re-raise the worker's CancelledError, while a
            # cancellation of this coroutine itself still propagates.
            await asyncio.wait({worker})

            if not worker.cancelled():
                # Surface an unexpected failure of the reader loop.
                worker.result()

            return

        await worker
81
82
    async def send_json(self,
                        data: typing.Any, *,
                        without_response: bool = False,
                        **kwargs) -> typing.Tuple[typing.Any, typing.Optional[dict]]:
        """Send ``data`` as JSON text and, if it carries ids, wait for the matching response.

        :param data: a JSON-RPC request object or a batch (sequence) of them.
        :param without_response: only send; do not register or wait for a response.
        :param kwargs: forwarded to ``ws_connect.send_str``.
        :returns: ``(response, None)``; ``(None, None)`` when nothing is awaited
            (``without_response`` is set or the payload carries no ids).
        :raises asyncio.TimeoutError: if ``self.timeout`` elapses before the response arrives.
        """
        assert self.ws_connect is not None

        if without_response:
            await self.ws_connect.send_str(self.json_serialize(data), **kwargs)
            return None, None

        request_ids = self._get_ids_from_json(data)
        # All ids of one batch share a single future, resolved with the whole batch result.
        future: asyncio.Future = asyncio.get_running_loop().create_future()

        for request_id in request_ids:
            self._pending[request_id] = future

        try:
            await self.ws_connect.send_str(self.json_serialize(data), **kwargs)

            if not request_ids:
                return None, None

            if self.timeout is not None:
                result = await asyncio.wait_for(future, timeout=self.timeout)
            else:
                result = await future
        finally:
            # Remove our waiters on every exit path (timeout, send failure, cancellation) so
            # ``_pending`` does not leak and a late response does not hit a cancelled future.
            # An entry that a newer request re-registered under the same id is left alone.
            for request_id in request_ids:
                if self._pending.get(request_id) is future:
                    del self._pending[request_id]

        return result, None
109
110
    def clear_pending(self) -> None:
        """Drop every pending request entry; the futures themselves are not resolved."""
        pending = self._pending
        pending.clear()
112
113
    @staticmethod
    def _get_ids_from_json(data: typing.Any) -> typing.Tuple[typedefs.JsonRpcIdType, ...]:
        """Collect the non-``None`` ``id`` values of a JSON-RPC object or batch.

        A mapping yields at most its own id. For a sequence, the ids of its mapping items are
        gathered in order. Anything else, or a falsy value, yields an empty tuple.
        """
        if not data:
            return ()

        if isinstance(data, typing.Mapping) and data.get('id') is not None:
            return (data['id'],)

        if not isinstance(data, typing.Sequence):
            return ()

        collected = []

        for item in data:
            if not isinstance(item, typing.Mapping):
                continue

            if item.get('id') is not None:
                collected.append(item['id'])

        return tuple(collected)
131
132
    async def _handle_ws_messages(self) -> None:
        """Read messages from ``self.ws_connect`` until iteration ends and dispatch text ones.

        Each text message is handled in its own background task. Non-text messages are
        skipped. If this reader is cancelled, every pending request is failed with a
        ``ServerError`` before the cancellation propagates.
        """
        assert self.ws_connect is not None

        ws_msg: http_websocket.WSMessage

        try:
            async for ws_msg in self.ws_connect:
                if ws_msg.type != http_websocket.WSMsgType.TEXT:
                    continue

                try:
                    task = asyncio.create_task(self._handle_single_ws_message(ws_msg))
                except Exception:
                    logger.warning('Can not process WS message.', exc_info=True)
                else:
                    # To avoid a task disappearing mid execution:
                    self._background_tasks.add(task)
                    task.add_done_callback(self._background_tasks.discard)
        except asyncio.CancelledError as e:
            # Cancellation is delivered at the ``async for`` await (``create_task`` never
            # raises it), so the whole loop must be inside the try for waiters to be notified.
            error = errors.ServerError(utils.get_exc_message(e)).with_traceback()
            self._notify_all_about_error(error)
            raise
153
154
    async def _handle_single_ws_message(self, ws_msg: http_websocket.WSMessage) -> None:
        """Decode one WS message and route it as a single JSON object or as a batch.

        Non-text messages and falsy JSON values are ignored. Undecodable payloads and values
        that are neither mappings nor sequences are logged as warnings.
        """
        if ws_msg.type != http_websocket.WSMsgType.TEXT:
            return

        try:
            payload = json.loads(ws_msg.data)
        except Exception:
            logger.warning('Cann\'t parse json.', exc_info=True)
            return

        if not payload:
            return

        if isinstance(payload, typing.Mapping):
            await self._handle_single_json_response(payload, ws_msg=ws_msg)
        elif isinstance(payload, typing.Sequence):
            await self._handle_json_responses(payload, ws_msg=ws_msg)
        else:
            logger.warning('Couldn\'t process the response.', extra={
                'json_response': payload,
            })
178
179
    async def _handle_single_json_response(self, json_response: typing.Mapping, *, ws_msg: web_ws.WSMessage) -> None:
        """Route one decoded JSON object.

        An object with ``method`` goes to the (awaited) request handler, if any. One whose
        ``id`` matches a pending request resolves that request. Anything else goes to the
        (synchronously called) unprocessed-response handler, if any.
        """
        if 'method' in json_response:
            request_handler = self._json_request_handler

            if request_handler is not None:
                await request_handler(
                    ws_connect=self.ws_connect,
                    ws_msg=ws_msg,
                    json_request=json_response,
                )

            return

        if 'id' in json_response and json_response['id'] in self._pending:
            self._notify_about_result(json_response['id'], json_response)
            return

        fallback_handler = self._unprocessed_json_response_handler

        if fallback_handler is not None:
            fallback_handler(
                ws_connect=self.ws_connect,
                ws_msg=ws_msg,
                json_response=json_response,
            )
195
196
    async def _handle_json_responses(self, json_responses: typing.Sequence, *, ws_msg: web_ws.WSMessage) -> None:
        """Route a decoded JSON array (batch).

        The batch counts as incoming requests when its first element is an object with a
        ``method`` key; it is then passed to the request handler, if any. Otherwise the ids it
        carries resolve the pending batch request. A batch without ids goes to the
        unprocessed-response handler, if any.

        Expects a non-empty sequence (the caller drops falsy JSON values).
        """
        if isinstance(json_responses[0], typing.Mapping) and 'method' in json_responses[0]:
            if self._json_request_handler is not None:
                # Hand over the decoded batch, exactly like the single-request path does;
                # previously the payload was omitted for batches.
                await self._json_request_handler(
                    ws_connect=self.ws_connect,
                    ws_msg=ws_msg,
                    json_request=json_responses,
                )
        else:
            response_ids = self._get_ids_from_json(json_responses)

            if response_ids:
                self._notify_about_results(response_ids, json_responses)
            elif self._unprocessed_json_response_handler is not None:
                self._unprocessed_json_response_handler(
                    ws_connect=self.ws_connect,
                    ws_msg=ws_msg,
                    json_response=json_responses,
                )
211
212
    def _notify_all_about_error(self, error: errors.JsonRpcError) -> None:
        """Fail every still-waiting pending request with ``error`` and clear the registry."""
        for future in self._pending.values():
            # All ids of a batch share one future, and a timed-out request leaves a cancelled
            # future behind; setting an exception on a done future raises InvalidStateError.
            if not future.done():
                future.set_exception(error)

        self.clear_pending()
217
218
    def _notify_about_result(self, response_id: typedefs.JsonRpcIdType, json_response: typing.Mapping) -> None:
        """Resolve the pending request ``response_id`` with ``json_response`` and unregister it."""
        future = self._pending.pop(response_id, None)

        # A late response may arrive after ``asyncio.wait_for`` cancelled the future on
        # timeout; ``set_result`` on a done future would raise InvalidStateError.
        if future is not None and not future.done():
            future.set_result(json_response)
223
224
    def _notify_about_results(self,
                              response_ids: typing.Sequence[typedefs.JsonRpcIdType],
                              json_response: typing.Sequence) -> None:
        """Resolve a pending batch request with the whole batch response.

        Every id in ``response_ids`` is removed from the registry. The first still-waiting
        future found is resolved with ``json_response``.
        """
        is_processed = False

        for response_id in response_ids:
            future = self._pending.pop(response_id, None)

            # We suppose that a batch result has the same ids that we sent,
            # and these ids share the same future. The future may already be done
            # (e.g. cancelled on timeout), in which case resolving it would raise.
            if future is not None and not is_processed and not future.done():
                future.set_result(json_response)
                is_processed = True