Passed
Pull Request — master (#13)
by Michael
03:50
created

WsJsonRpcClient._get_ids_from_json()   A

Complexity

Conditions 5

Size

Total Lines 18
Code Lines 13

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 13
dl 0
loc 18
rs 9.2833
c 0
b 0
f 0
cc 5
nop 1
1
import asyncio
2
import json
3
import logging
4
import typing
5
6
from aiohttp import ClientSession, http_websocket, web_ws
7
8
from .base import BaseJsonRpcClient
9
from .. import errors, typedefs, utils
10
11
12
__all__ = (
13
    'WsJsonRpcClient',
14
)
15
16
logger = logging.getLogger(__name__)
17
18
19
class WsJsonRpcClient(BaseJsonRpcClient):
    """JSON-RPC client that exchanges messages over a WebSocket connection.

    Requests that carry an ``id`` are remembered in ``_pending`` until the
    matching response is read from the socket by a background worker.
    """

    # Address handed to ``session.ws_connect`` when the client opens the socket itself.
    url: typing.Optional[str]
    # HTTP session used to open the socket; set in ``__init__``, may be created in ``connect``.
    session: typing.Optional[ClientSession]
    # The WebSocket connection; either supplied by the caller or opened in ``connect``.
    ws_connect: typing.Optional[typedefs.WSConnectType]
    # Extra keyword arguments forwarded to ``session.ws_connect``.
    ws_connect_kwargs: dict
    # Timeout for waiting on a response in ``send_json``; ``None`` waits indefinitely.
    timeout: typing.Optional[int]
    # Delay between "is the socket closed?" checks; ``None`` disables the check worker.
    _connection_check_interval: typing.Optional[int]
    # Maps a request id to the future that receives its response.
    _pending: typing.Dict[typing.Any, asyncio.Future]
    # Background task that reads incoming WebSocket messages.
    _message_worker: typing.Optional[asyncio.Future] = None
    # Background task that polls the connection state.
    _check_worker: typing.Optional[asyncio.Future] = None
    # True when the session was supplied by the caller (it is then never closed here).
    _session_is_outer: bool
    # True when the connection was supplied by the caller (it is then never closed here).
    _ws_connect_is_outer: bool
    # Awaited for incoming messages that contain a ``method`` key.
    _json_request_handler: typing.Optional[typing.Callable] = None
    # Called for incoming responses that match no pending request.
    _unprocessed_json_response_handler: typing.Optional[typing.Callable] = None
    # Strong references to per-message tasks so they are not garbage collected mid-run.
    _background_tasks: typing.Set
33
34
    def __init__(self,
                 url: typing.Optional[str] = None, *,
                 session: typing.Optional[ClientSession] = None,
                 ws_connect: typing.Optional[typedefs.WSConnectType] = None,
                 timeout: typing.Optional[int] = 5,
                 connection_check_interval: typing.Optional[int] = 1,
                 json_request_handler: typing.Optional[typing.Callable] = None,
                 unprocessed_json_response_handler: typing.Optional[typing.Callable] = None,
                 **ws_connect_kwargs) -> None:
        """Remember the connection settings; nothing is opened here.

        At least one of ``url``, ``session`` or ``ws_connect`` has to be given.

        :param url: address passed to ``session.ws_connect`` in ``connect``.
        :param session: caller-owned session; it is never closed by this client.
        :param ws_connect: caller-owned connection; it is never closed by this client.
        :param timeout: timeout used when waiting for a response in ``send_json``;
            ``None`` disables the timeout.
        :param connection_check_interval: delay between connection state checks;
            ``None`` means the check worker is not started.
        :param json_request_handler: awaited for incoming messages with a ``method`` key.
        :param unprocessed_json_response_handler: called for responses that match
            no pending request.
        :param ws_connect_kwargs: forwarded to ``session.ws_connect``.
        """
        # Logically the same check as the three-way "or": something to connect with must exist.
        assert not (session is None and url is None and ws_connect is None)

        # Plain settings.
        self.url = url
        self.timeout = timeout
        self._connection_check_interval = connection_check_interval

        # Resources that may belong to the caller; the flags record that we must not close them.
        self.session = session
        self._session_is_outer = session is not None

        self.ws_connect = ws_connect
        self.ws_connect_kwargs = ws_connect_kwargs
        self._ws_connect_is_outer = ws_connect is not None

        # Runtime bookkeeping and callbacks.
        self._pending = {}
        self._json_request_handler = json_request_handler
        self._unprocessed_json_response_handler = unprocessed_json_response_handler
        self._background_tasks = set()
60
61
    async def connect(self) -> None:
        """Open the WebSocket if needed and start the background workers.

        When no connection was supplied, a session is created on demand and
        the socket is opened against ``self.url``. If opening fails, whatever
        was already created is released via ``disconnect`` and the error is
        re-raised.
        """
        if self.ws_connect is None:
            if self.session is None:
                self.session = ClientSession(json_serialize=self.json_serialize)

            assert self.url is not None and self.session is not None

            try:
                self.ws_connect = await self.session.ws_connect(self.url, **self.ws_connect_kwargs)
            except Exception:
                await self.disconnect()
                raise

        # Reader task: consumes incoming messages for the lifetime of the socket.
        self._message_worker = asyncio.create_task(self._handle_ws_messages())

        # Optional watchdog task that notices a closed socket.
        if self._connection_check_interval is not None:
            self._check_worker = asyncio.create_task(self._check_ws_connection())
78
79
    async def disconnect(self) -> None:
        """Close the resources this client owns and wait for its workers.

        A connection or session supplied by the caller is left open. The
        message worker is awaited first, then the check worker.
        """
        if self.ws_connect is not None:
            if not self._ws_connect_is_outer:
                await self.ws_connect.close()

        if self.session is not None:
            if not self._session_is_outer:
                await self.session.close()

        # Attributes are read lazily, one at a time, right before each await.
        for worker_attr in ('_message_worker', '_check_worker'):
            worker = getattr(self, worker_attr)

            if worker is not None:
                await worker
91
92
    async def send_json(self,
                        data: typing.Any, *,
                        without_response: bool = False,
                        **kwargs) -> typing.Tuple[typing.Any, typing.Optional[dict]]:
        """Send ``data`` over the WebSocket and, if expected, wait for the response.

        :param data: JSON-serializable request (a single request or a batch).
        :param without_response: when true, send and return immediately without
            registering the request ids.
        :param kwargs: forwarded to ``ws_connect.send_str``.
        :returns: ``(response, None)``; ``(None, None)`` when no response is
            awaited (``without_response`` or no ids in ``data``).
        :raises errors.ServerError: when sending fails with ``ConnectionResetError``;
            all pending requests are failed with the same error.
        :raises asyncio.TimeoutError: when no response arrives within ``self.timeout``.
        """
        assert self.ws_connect is not None

        request_ids = () if without_response else self._get_ids_from_json(data)
        future: asyncio.Future = asyncio.get_running_loop().create_future()

        # Register before sending so a fast response cannot arrive unobserved.
        for request_id in request_ids:
            self._pending[request_id] = future

        try:
            try:
                await self.ws_connect.send_str(self.json_serialize(data), **kwargs)
            except ConnectionResetError as e:
                error = errors.ServerError(utils.get_exc_message(e)).with_traceback()
                self._notify_all_about_error(error)
                raise error

            if not request_ids:
                return None, None

            if self.timeout is None:
                result = await future
            else:
                result = await asyncio.wait_for(future, timeout=self.timeout)

            return result, None
        finally:
            # On timeout, cancellation or a failed send the ids would otherwise stay
            # registered forever, and a late response would hit a cancelled future.
            # Only remove entries that still point at this call's future.
            for request_id in request_ids:
                if self._pending.get(request_id) is future:
                    del self._pending[request_id]
130
131
    @staticmethod
    def _get_ids_from_json(data: typing.Any) -> typing.Tuple[typedefs.JsonRpcIdType, ...]:
        """Collect the non-null ``id`` values found in a JSON-RPC payload.

        A mapping yields at most one id; a sequence yields the ids of every
        mapping item that has one, in order. Anything else yields ``()``.
        """
        if not data:
            return ()

        def has_id(candidate: typing.Any) -> bool:
            # Only mappings with a non-null "id" count as identifiable messages.
            return isinstance(candidate, typing.Mapping) and candidate.get('id') is not None

        if has_id(data):
            return (data['id'],)

        if not isinstance(data, typing.Sequence):
            return ()

        return tuple(entry['id'] for entry in data if has_id(entry))
149
150
    async def _handle_ws_messages(self) -> None:
        """Read messages from the WebSocket until it stops and dispatch them.

        Every TEXT message is handled in its own task. Once reading stops,
        for whatever reason, all still-pending requests are failed with a
        ``ServerError`` so that their callers do not wait for a response that
        can no longer arrive.
        """
        assert self.ws_connect is not None

        ws_msg_types_to_stop = (
            http_websocket.WSMsgType.CLOSE,
            http_websocket.WSMsgType.CLOSING,
            http_websocket.WSMsgType.CLOSED,
            http_websocket.WSMsgType.ERROR,
        )

        ws_msg: http_websocket.WSMessage

        async for ws_msg in self.ws_connect:
            if ws_msg.type == http_websocket.WSMsgType.TEXT:
                try:
                    task = asyncio.create_task(self._handle_single_ws_message(ws_msg))
                except asyncio.CancelledError as e:
                    error = errors.InternalError(utils.get_exc_message(e)).with_traceback()
                    self._notify_all_about_error(error)
                    break
                except Exception:
                    logger.warning('Can\'t process WS message.', exc_info=True)
                else:
                    # To avoid a task disappearing mid execution:
                    self._background_tasks.add(task)
                    task.add_done_callback(self._background_tasks.discard)
            elif ws_msg.type in ws_msg_types_to_stop:
                break

        # aiohttp ends the iteration itself on CLOSE/CLOSING/CLOSED without yielding
        # the message, so the notification must happen after the loop rather than
        # only inside it. With nothing pending this call is a no-op.
        self._notify_all_about_error(errors.ServerError('Connection is closed'))
180
181
    async def _check_ws_connection(self) -> None:
        """Poll the connection and fail all pending requests once it is closed.

        The ``closed`` flag is checked every ``_connection_check_interval``;
        the worker ends after the first time it sees the socket closed.
        """
        assert self.ws_connect is not None

        while not self.ws_connect.closed:
            await asyncio.sleep(self._connection_check_interval)

        self._notify_all_about_error(errors.ServerError('Connection is closed'))
191
192
    async def _handle_single_ws_message(self, ws_msg: http_websocket.WSMessage) -> None:
        """Decode one text message and route it by the shape of its JSON payload.

        Non-text messages, undecodable JSON and empty payloads are ignored
        (undecodable JSON is logged). A mapping is treated as a single
        message, a sequence as a batch; any other payload is logged.
        """
        if ws_msg.type != http_websocket.WSMsgType.text:
            return

        try:
            payload = json.loads(ws_msg.data)
        except Exception:
            logger.warning('Can\'t parse json.', exc_info=True)
            return

        if not payload:
            return

        if isinstance(payload, typing.Mapping):
            await self._handle_single_json_response(payload, ws_msg=ws_msg)
        elif isinstance(payload, typing.Sequence):
            await self._handle_json_responses(payload, ws_msg=ws_msg)
        else:
            logger.warning('Couldn\'t process the response.', extra={
                'json_response': payload,
            })
216
217
    async def _handle_single_json_response(self, json_response: typing.Mapping, *, ws_msg: web_ws.WSMessage) -> None:
        """Route one decoded JSON object received from the socket.

        * An object with a ``method`` key is a request from the peer and is
          passed to the request handler, if one is configured.
        * An object whose ``id`` matches a pending request resolves it.
        * Anything else goes to the unprocessed-response handler, if configured.
        """
        if 'method' in json_response:
            request_handler = self._json_request_handler

            if request_handler is not None:
                await request_handler(
                    ws_connect=self.ws_connect,
                    ws_msg=ws_msg,
                    json_request=json_response,
                )

            return

        if 'id' in json_response and json_response['id'] in self._pending:
            self._notify_about_result(json_response['id'], json_response)
            return

        fallback_handler = self._unprocessed_json_response_handler

        if fallback_handler is not None:
            # NOTE(review): unlike the request handler, this one is called without await.
            fallback_handler(
                ws_connect=self.ws_connect,
                ws_msg=ws_msg,
                json_response=json_response,
            )
233
234
    async def _handle_json_responses(self, json_responses: typing.Sequence, *, ws_msg: web_ws.WSMessage) -> None:
        """Route a decoded JSON batch received from the socket.

        The first item decides how the batch is treated: if it is a mapping
        with a ``method`` key, the batch is passed to the request handler.
        Otherwise the batch resolves the pending request whose ids it
        carries, or goes to the unprocessed-response handler when it carries
        no ids.
        """
        if not json_responses:
            # Nothing to inspect; also protects the ``[0]`` lookup below.
            return

        if isinstance(json_responses[0], typing.Mapping) and 'method' in json_responses[0]:
            if self._json_request_handler is not None:
                # Pass the payload the same way the single-message path does, so one
                # handler signature works for both single requests and batches.
                await self._json_request_handler(
                    ws_connect=self.ws_connect,
                    ws_msg=ws_msg,
                    json_request=json_responses,
                )
        else:
            response_ids = self._get_ids_from_json(json_responses)

            if response_ids:
                self._notify_about_results(response_ids, json_responses)
            elif self._unprocessed_json_response_handler is not None:
                self._unprocessed_json_response_handler(
                    ws_connect=self.ws_connect,
                    ws_msg=ws_msg,
                    json_response=json_responses,
                )
249
250
    def _notify_all_about_error(self, error: Exception) -> None:
251
        for future in self._pending.values():
252
            try:
253
                future.set_exception(error)
254
            except asyncio.InvalidStateError:
255
                pass
256
257
        self._pending.clear()
258
259
    def _notify_about_result(self, response_id: typedefs.JsonRpcIdType, json_response: typing.Mapping) -> None:
        """Deliver ``json_response`` to the request registered under ``response_id``.

        The id is removed from the pending map. Nothing happens when the id is
        unknown or when its future is already finished.
        """
        future = self._pending.pop(response_id, None)

        # A future cancelled by a timeout is "done"; calling set_result on it would
        # raise InvalidStateError inside the message-handling task.
        if future is not None and not future.done():
            future.set_result(json_response)
264
265
    def _notify_about_results(self,
                              response_ids: typing.Sequence[typedefs.JsonRpcIdType],
                              json_response: typing.Sequence) -> None:
        """Deliver a batch response to the request that is waiting for it.

        All ``response_ids`` are removed from the pending map. The whole batch
        is handed to the first future found that is still unfinished.
        """
        is_processed = False

        for response_id in response_ids:
            future = self._pending.pop(response_id, None)

            if future is None or is_processed:
                continue

            if future.done():
                # Already cancelled (e.g. by a timeout) or resolved; set_result would
                # raise InvalidStateError.
                continue

            # We suppose that a batch result has the same ids that we sent.
            # And these ids have the same future.
            future.set_result(json_response)
            is_processed = True
279