Completed — push to master (3b3d15...f91bc9) by Michael
Duration 14s (queued 12s)
created

WsJsonRpcClient._check_ws_connection() — rating A

Complexity

Conditions 4

Size

Total Lines 13
Code Lines 11

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 11
dl 0
loc 13
rs 9.85
c 0
b 0
f 0
cc 4
nop 1
1
import asyncio
2
import json
3
import logging
4
import typing
5
6
from aiohttp import ClientSession, http_websocket, web_ws
7
8
from .base import BaseJsonRpcClient
9
from .. import errors, typedefs, utils
10
11
12
# Names exported by ``from <module> import *``.
__all__ = ('WsJsonRpcClient',)

# Logger named after this module; used for warnings about WS messages that can't be processed.
logger = logging.getLogger(__name__)
17
18
19
class WsJsonRpcClient(BaseJsonRpcClient):
    """JSON-RPC client that exchanges JSON messages over a WebSocket connection.

    Outgoing data is serialized with ``self.json_serialize`` (not defined here;
    presumably provided by ``BaseJsonRpcClient``) and sent as text frames.
    Responses are matched back to the waiting caller by their ``id`` through
    ``self._pending``.

    Incoming messages that contain ``method`` are passed to ``json_request_handler``.
    Other messages that match no pending request go to
    ``unprocessed_json_response_handler``.

    A session or connection passed in by the caller ("outer") is never closed
    by this client. Ones it creates itself are closed in :meth:`disconnect`.
    """

    url: typing.Optional[str]
    ws_connect: typing.Optional[typedefs.WSConnectType]
    ws_connect_kwargs: dict
    _timeout: typing.Optional[float]
    _timeout_for_data_receiving: typing.Optional[float]
    _connection_check_interval: typing.Optional[float]
    # Request id -> future awaited by send_json(); all ids of one batch share one future.
    _pending: typing.Dict[typing.Any, asyncio.Future]
    _message_worker: typing.Optional[asyncio.Future] = None
    _check_worker: typing.Optional[asyncio.Future] = None
    _session_is_outer: bool
    _ws_connect_is_outer: bool
    _json_request_handler: typing.Optional[typing.Callable] = None
    _unprocessed_json_response_handler: typing.Optional[typing.Callable] = None
    # Strong references to per-message tasks so they are not garbage-collected mid-run.
    _background_tasks: typing.Set
    _is_closed: bool = True

    def __init__(self,
                 url: typing.Optional[str] = None, *,
                 session: typing.Optional[ClientSession] = None,
                 ws_connect: typing.Optional[typedefs.WSConnectType] = None,
                 timeout: typing.Optional[float] = 60,
                 timeout_for_data_receiving: typing.Optional[float] = None,
                 connection_check_interval: typing.Optional[float] = 5,
                 json_request_handler: typing.Optional[typing.Callable] = None,
                 unprocessed_json_response_handler: typing.Optional[typing.Callable] = None,
                 **ws_connect_kwargs) -> None:
        """
        :param url: WebSocket URL, used by :meth:`connect` when no ``ws_connect`` is given.
        :param session: Outer ``ClientSession`` used to open the connection. It is never
            closed by this client.
        :param ws_connect: Already opened outer WebSocket connection. It is never closed
            by this client.
        :param timeout: Seconds :meth:`send_json` waits for a response. ``None`` waits
            indefinitely.
        :param timeout_for_data_receiving: Timeout passed to every ``ws_connect.receive()`` call.
        :param connection_check_interval: Seconds between connection-state checks. ``None``
            disables the check worker.
        :param json_request_handler: Awaited with ``ws_connect``, ``ws_msg`` and ``json_request``
            for incoming messages that contain ``method``.
        :param unprocessed_json_response_handler: Called (not awaited) with ``ws_connect``,
            ``ws_msg`` and ``json_response`` for messages that match no pending request.
        :param ws_connect_kwargs: Extra keyword arguments forwarded to ``session.ws_connect()``.
        """
        # At least one way to reach the server must be provided.
        assert session is not None or url is not None or ws_connect is not None

        self.url = url
        self._timeout = timeout
        self._timeout_for_data_receiving = timeout_for_data_receiving
        self._connection_check_interval = connection_check_interval

        self.session = session
        self._session_is_outer = session is not None  # We don't close an outer session.

        self.ws_connect = ws_connect
        self.ws_connect_kwargs = ws_connect_kwargs
        self._ws_connect_is_outer = ws_connect is not None  # We don't close an outer WS connection.

        self._pending = {}
        self._json_request_handler = json_request_handler
        self._unprocessed_json_response_handler = unprocessed_json_response_handler
        self._background_tasks = set()

    async def connect(self) -> None:
        """Open the WebSocket connection (unless an outer one was given) and start the workers.

        A ``ClientSession`` is created only when neither a session nor a connection was
        supplied. If opening the connection fails, :meth:`disconnect` cleans up whatever
        was created, and the original exception is re-raised.
        """
        self._is_closed = False

        if self.session is None and self.ws_connect is None:
            self.session = ClientSession(json_serialize=self.json_serialize)

        if self.ws_connect is None:
            assert self.url is not None and self.session is not None

            try:
                self.ws_connect = await self.session.ws_connect(self.url, **self.ws_connect_kwargs)
            except Exception:
                await self.disconnect()
                raise

        self._message_worker = asyncio.create_task(self._handle_ws_messages())

        if self._connection_check_interval is not None:
            self._check_worker = asyncio.create_task(self._check_ws_connection())

    async def disconnect(self) -> None:
        """Close owned resources, stop the workers and fail requests that are still waiting.

        Outer sessions and connections are left open. With an outer connection, the
        message worker is given up to 60 seconds to stop.
        """
        self._is_closed = True

        try:
            if self.ws_connect is not None and not self._ws_connect_is_outer:
                await self.ws_connect.close()

            if self.session is not None and not self._session_is_outer:
                await self.session.close()

            if self._message_worker is not None:
                if self._ws_connect_is_outer:
                    # An outer connection isn't closed by us. The worker stops when it sees
                    # ``_is_closed`` (after a receive timeout or a message) or when the
                    # connection closes.
                    await asyncio.wait_for(self._message_worker, timeout=60)
                else:
                    await self._message_worker

            if self._check_worker is not None:
                self._check_worker.cancel()
                await self._check_worker
        finally:
            # Nothing will answer the remaining requests any more. Fail them now instead of
            # leaving their callers waiting for the response timeout (or forever if it is None).
            self._notify_all_about_error(errors.ServerError('Connection is closed'))

    async def send_json(self,
                        data: typing.Any, *,
                        without_response: bool = False,
                        **kwargs) -> typing.Tuple[typing.Any, typing.Optional[dict]]:
        """Send *data* and, unless told otherwise, wait for the matching response.

        :param data: A JSON-RPC message or a batch (sequence) of messages.
        :param without_response: Send and return immediately without waiting for a response.
        :param kwargs: Forwarded to ``ws_connect.send_str()``.
        :return: ``(response, None)``. The response is ``None`` when nothing is awaited,
            either because of ``without_response`` or because *data* carries no ids.
        :raises errors.ServerError: If the connection is reset while sending, or if
            pending requests are failed (e.g. the connection was found closed).
        :raises asyncio.TimeoutError: If no response arrives within the configured timeout.
        """
        assert self.ws_connect is not None

        if without_response:
            await self._write_json(data, **kwargs)
            return None, None

        request_ids = self._get_ids_from_json(data)
        future: asyncio.Future = asyncio.get_running_loop().create_future()

        # Every id of a batch points to the same future; the batch response resolves it once.
        for request_id in request_ids:
            self._pending[request_id] = future

        try:
            await self._write_json(data, **kwargs)

            if not request_ids:
                return None, None

            if self._timeout is None:
                result = await future
            else:
                result = await asyncio.wait_for(future, timeout=self._timeout)
        finally:
            # Drop our entries whatever happened (timeout, cancellation, send error).
            # This stops them piling up in ``_pending``, and late responses are ignored.
            self._forget_pending(request_ids, future)

        return result, None

    async def _write_json(self, data: typing.Any, **kwargs) -> None:
        """Serialize *data* and send it as a text frame.

        A ``ConnectionResetError`` is converted to ``errors.ServerError``. Before the
        error is raised, every pending request is failed with it.
        """
        assert self.ws_connect is not None

        try:
            await self.ws_connect.send_str(self.json_serialize(data), **kwargs)
        except ConnectionResetError as e:
            error = errors.ServerError(utils.get_exc_message(e)).with_traceback()
            self._notify_all_about_error(error)
            raise error

    def _forget_pending(self,
                        request_ids: typing.Iterable[typedefs.JsonRpcIdType],
                        future: asyncio.Future) -> None:
        """Remove the ``_pending`` entries for *request_ids* that still point to *future*.

        Entries that another request has since registered under the same id are left alone.
        """
        for request_id in request_ids:
            if self._pending.get(request_id) is future:
                del self._pending[request_id]

    @staticmethod
    def _get_ids_from_json(data: typing.Any) -> typing.Tuple[typedefs.JsonRpcIdType, ...]:
        """Return the non-null ``id`` values of a single message or of a batch.

        Empty data, and anything that is neither a mapping nor a sequence, yields ``()``.
        """
        if not data:
            return ()

        if isinstance(data, typing.Mapping) and data.get('id') is not None:
            return (
                data['id'],
            )

        if isinstance(data, typing.Sequence):
            return tuple(
                item['id']
                for item in data
                if isinstance(item, typing.Mapping) and item.get('id') is not None
            )

        return ()

    async def _handle_ws_messages(self) -> None:
        """Read frames until the connection closes or the client is closed.

        Each text frame is processed in its own background task, so reading continues
        while that message is handled. Non-text frames are ignored.
        """
        assert self.ws_connect is not None

        while True:
            try:
                ws_msg: http_websocket.WSMessage = await self.ws_connect.receive(
                    timeout=self._timeout_for_data_receiving,
                )
            except asyncio.TimeoutError:
                # A receive timeout only gives us a chance to re-check ``_is_closed``.
                if self._is_closed:
                    break
                continue

            if ws_msg.type in (
                    http_websocket.WSMsgType.CLOSE,
                    http_websocket.WSMsgType.CLOSING,
                    http_websocket.WSMsgType.CLOSED,
            ):
                break

            if ws_msg.type != http_websocket.WSMsgType.TEXT:
                continue

            # ``create_task`` only schedules the coroutine, so it never raises CancelledError;
            # any other failure to schedule is logged and the loop keeps reading.
            try:
                task = asyncio.create_task(self._handle_single_ws_message(ws_msg))
            except Exception:
                logger.warning('Can\'t process WS message.', exc_info=True)
            else:
                # To avoid a task disappearing mid execution:
                self._background_tasks.add(task)
                task.add_done_callback(self._background_tasks.discard)

            if self._is_closed:
                break

    async def _check_ws_connection(self) -> None:
        """Periodically check whether the connection has closed.

        Every ``_connection_check_interval`` seconds, this checks ``ws_connect.closed``. Once
        it is closed, all pending requests are failed with ``errors.ServerError`` and the
        loop stops. Cancellation ends the loop quietly.
        """
        assert self.ws_connect is not None

        try:
            while not self._is_closed:
                if self.ws_connect.closed:
                    error = errors.ServerError('Connection is closed')
                    self._notify_all_about_error(error)
                    break

                await asyncio.sleep(self._connection_check_interval)  # type: ignore
        except asyncio.CancelledError:
            pass

    async def _handle_single_ws_message(self, ws_msg: http_websocket.WSMessage) -> None:
        """Parse one text frame and route it as a single message or as a batch.

        Unparseable JSON is logged and dropped, and empty payloads are ignored.
        """
        if ws_msg.type != http_websocket.WSMsgType.TEXT:
            return

        try:
            json_response = json.loads(ws_msg.data)
        except Exception:
            logger.warning('Can\'t parse json.', exc_info=True)
            return

        if not json_response:
            return

        if isinstance(json_response, typing.Mapping):
            await self._handle_single_json_response(json_response, ws_msg=ws_msg)
            return

        if isinstance(json_response, typing.Sequence):
            await self._handle_json_responses(json_response, ws_msg=ws_msg)
            return

        logger.warning('Couldn\'t process the response.', extra={
            'json_response': json_response,
        })

    async def _handle_single_json_response(self, json_response: typing.Mapping, *, ws_msg: web_ws.WSMessage) -> None:
        """Dispatch one incoming JSON object.

        There are three cases:

        * An object containing ``method`` is awaited through the request handler (if any).
        * An object whose ``id`` is pending resolves that request.
        * Anything else goes to the unprocessed-response handler (if any), which is called
          without awaiting.
        """
        if 'method' in json_response:
            if self._json_request_handler is not None:
                await self._json_request_handler(
                    ws_connect=self.ws_connect,
                    ws_msg=ws_msg,
                    json_request=json_response,
                )
        elif 'id' in json_response and json_response['id'] in self._pending:
            self._notify_about_result(json_response['id'], json_response)
        elif self._unprocessed_json_response_handler is not None:
            self._unprocessed_json_response_handler(
                ws_connect=self.ws_connect,
                ws_msg=ws_msg,
                json_response=json_response,
            )

    async def _handle_json_responses(self, json_responses: typing.Sequence, *, ws_msg: web_ws.WSMessage) -> None:
        """Dispatch a non-empty incoming batch.

        The first element decides whether the batch is a set of requests (it contains
        ``method``) or responses. Responses with known ids resolve their shared future.
        Other batches go to the unprocessed-response handler.
        """
        if isinstance(json_responses[0], typing.Mapping) and 'method' in json_responses[0]:
            if self._json_request_handler is not None:
                # Pass the parsed batch, just like the single-message path does.
                await self._json_request_handler(
                    ws_connect=self.ws_connect,
                    ws_msg=ws_msg,
                    json_request=json_responses,
                )
        else:
            response_ids = self._get_ids_from_json(json_responses)

            if response_ids:
                self._notify_about_results(response_ids, json_responses)
            elif self._unprocessed_json_response_handler is not None:
                self._unprocessed_json_response_handler(
                    ws_connect=self.ws_connect,
                    ws_msg=ws_msg,
                    json_response=json_responses,
                )

    def _notify_all_about_error(self, error: Exception) -> None:
        """Fail every pending request with *error* and clear ``_pending``."""
        for future in self._pending.values():
            # A batch future appears once per id; it may also have timed out or been cancelled.
            if not future.done():
                future.set_exception(error)

        self._pending.clear()

    def _notify_about_result(self, response_id: typedefs.JsonRpcIdType, json_response: typing.Mapping) -> None:
        """Resolve the request waiting for *response_id*, if it is still waiting."""
        future = self._pending.pop(response_id, None)

        # A future that already timed out or was cancelled can't take a result.
        if future is not None and not future.done():
            future.set_result(json_response)

    def _notify_about_results(self,
                              response_ids: typing.Sequence[typedefs.JsonRpcIdType],
                              json_response: typing.Sequence) -> None:
        """Resolve a batch request with the whole batch response.

        All *response_ids* are removed from ``_pending``. The (shared) future is resolved
        at most once.
        """
        is_processed = False

        for response_id in response_ids:
            future = self._pending.pop(response_id, None)

            if future is None or is_processed or future.done():
                continue

            # We suppose that a batch result has the same ids that we sent.
            # And these ids have the same future.
            future.set_result(json_response)
            is_processed = True