Passed
Push — master ( 738c54...1e73b7 )
by Michael
04:59
created

WsJsonRpcClient._handle_json_responses()   B

Complexity

Conditions 6

Size

Total Lines 14
Code Lines 12

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 12
dl 0
loc 14
rs 8.6666
c 0
b 0
f 0
cc 6
nop 4
1
import asyncio
2
import json
3
import logging
4
import typing
5
6
from aiohttp import ClientSession, client_ws, http_websocket, web_ws
7
8
from .base import BaseJsonRpcClient
9
from .. import errors, utils
10
11
12
__all__ = ('WsJsonRpcClient',)

logger = logging.getLogger(__name__)

# Either kind of aiohttp WebSocket object the client can drive: a client-side
# ``ClientWebSocketResponse`` or a server-side ``web_ws.WebSocketResponse``.
WSConnectType = typing.Union[client_ws.ClientWebSocketResponse, web_ws.WebSocketResponse]
19
20
21
class WsJsonRpcClient(BaseJsonRpcClient):
    """JSON-RPC client that communicates over a single WebSocket connection.

    The client either owns its transport or borrows one:

    * Owned: it creates a ``ClientSession`` and opens the WebSocket itself in
      :meth:`connect`.
    * Borrowed: it uses an outer ``session`` or ``ws_connect`` supplied by the
      caller. Borrowed objects are never closed by :meth:`disconnect`.

    Responses are matched to requests by their JSON-RPC ``id``.
    :meth:`send_json` registers a future per request id in ``_pending``, and a
    background worker started by :meth:`connect` resolves those futures as
    messages arrive.

    Incoming messages are routed as follows:

    * Messages carrying a ``method`` go to the optional
      ``json_request_handler``.
    * Responses that match no pending request go to the optional
      ``unprocessed_json_response_handler``.
    """

    url: typing.Optional[str]
    ws_connect: typing.Optional[WSConnectType]
    timeout: typing.Optional[int]
    ws_connect_kwargs: dict
    # Maps request id -> future awaiting its response. All ids of one batch
    # request share the same future.
    _pending: typing.Dict[typing.Any, asyncio.Future]
    _message_worker: typing.Optional[asyncio.Future] = None
    _session_is_outer: bool
    _ws_connect_is_outer: bool
    _json_request_handler: typing.Optional[typing.Callable] = None
    _unprocessed_json_response_handler: typing.Optional[typing.Callable] = None

    def __init__(self,
                 url: typing.Optional[str] = None, *,
                 session: typing.Optional[ClientSession] = None,
                 ws_connect: typing.Optional[WSConnectType] = None,
                 timeout: typing.Optional[int] = 5,
                 json_request_handler: typing.Optional[typing.Callable] = None,
                 unprocessed_json_response_handler: typing.Optional[typing.Callable] = None,
                 **ws_connect_kwargs) -> None:
        """Configure the client. No I/O happens until :meth:`connect`.

        Args:
            url: WebSocket URL. Required when no ``ws_connect`` is given.
            session: an outer ``ClientSession``. It is used to open the
                WebSocket and is not closed by :meth:`disconnect`.
            ws_connect: an already-open outer WebSocket. It is not closed by
                :meth:`disconnect`.
            timeout: seconds :meth:`send_json` waits for a response.
                ``None`` waits indefinitely.
            json_request_handler: coroutine function awaited with keyword
                arguments ``ws_connect``, ``ws_msg`` and ``json_request`` for
                incoming messages that carry a ``method``.
            unprocessed_json_response_handler: callable invoked (not awaited)
                with keyword arguments ``ws_connect``, ``ws_msg`` and
                ``json_response`` for responses matching no pending request.
            **ws_connect_kwargs: forwarded to ``session.ws_connect``.
        """
        # At least one way of reaching the server is required.
        assert session is not None or url is not None or ws_connect is not None

        self.url = url
        self.timeout = timeout

        self.session = session
        self._session_is_outer = session is not None

        self.ws_connect = ws_connect
        self.ws_connect_kwargs = ws_connect_kwargs
        self._ws_connect_is_outer = ws_connect is not None

        self._pending = {}
        self._json_request_handler = json_request_handler
        self._unprocessed_json_response_handler = unprocessed_json_response_handler

    async def connect(self) -> None:
        """Open the WebSocket (unless one was supplied) and start the message worker.

        An owned ``ClientSession`` is created when neither a session nor a
        WebSocket was supplied. If opening the WebSocket fails, owned resources
        are released through :meth:`disconnect` and the exception is re-raised.
        """
        if not self.session and not self.ws_connect:
            self.session = ClientSession(json_serialize=self.json_serialize)

        if self.ws_connect is None:
            assert self.url is not None and self.session is not None

            try:
                self.ws_connect = await self.session.ws_connect(self.url, **self.ws_connect_kwargs)
            except Exception:
                await self.disconnect()
                raise

        self._message_worker = asyncio.ensure_future(self._handle_ws_messages())

    async def disconnect(self) -> None:
        """Close owned resources, then wait for the message worker to finish.

        A WebSocket or session passed in by the caller is left open.

        NOTE(review): with an outer ``ws_connect`` that is still open, awaiting
        the worker blocks until that connection is closed elsewhere.
        """
        if self.ws_connect and not self._ws_connect_is_outer:
            await self.ws_connect.close()

        if self.session and not self._session_is_outer:
            await self.session.close()

        if self._message_worker:
            await self._message_worker

    async def send_json(self,
                        data: typing.Any, *,
                        without_response: bool = False,
                        **kwargs) -> typing.Tuple[typing.Any, typing.Optional[dict]]:
        """Send ``data`` and, unless ``without_response`` is set, wait for its response.

        Args:
            data: a JSON-RPC request dict or a batch (list) of them.
            without_response: only send; do not register or wait for a response.
            **kwargs: forwarded to ``ws_connect.send_str``.

        Returns:
            ``(response, None)``, where ``response`` is the decoded JSON
            response (a list for batches). Returns ``(None, None)`` when
            nothing is awaited: ``without_response`` is set, or ``data``
            carries no ids.

        Raises:
            asyncio.TimeoutError: no response arrived within ``self.timeout``
                seconds.
            Any error the message worker sets on pending requests (e.g. when
            it is cancelled).
        """
        assert self.ws_connect is not None

        if without_response:
            await self.ws_connect.send_str(self.json_serialize(data), **kwargs)
            return None, None

        request_ids = self._get_ids_from_json(data)
        future: asyncio.Future = asyncio.Future()

        for request_id in request_ids:
            self._pending[request_id] = future

        try:
            await self.ws_connect.send_str(self.json_serialize(data), **kwargs)

            if not request_ids:
                return None, None

            waiter: typing.Awaitable = future
            if self.timeout is not None:
                waiter = asyncio.wait_for(future, timeout=self.timeout)

            result = await waiter
        finally:
            # Drop registrations left behind by a send failure, a timeout or a
            # cancellation. Otherwise they leak, and a late response would try
            # to resolve an already-cancelled future. A normal response has
            # already popped them. Only remove entries that still point at
            # this call's future.
            for request_id in request_ids:
                if self._pending.get(request_id) is future:
                    del self._pending[request_id]

        return result, None

    def clear_pending(self) -> None:
        """Forget every pending request without resolving its future."""
        self._pending.clear()

    @staticmethod
    def _get_ids_from_json(data: typing.Any) -> list:
        """Return the non-``None`` ``id`` values of a request/response dict or batch list."""
        if not data:
            return []

        if isinstance(data, dict) and data.get('id') is not None:
            return [data['id']]

        if isinstance(data, list):
            return [
                item['id']
                for item in data
                if isinstance(item, dict) and item.get('id') is not None
            ]

        return []

    async def _handle_ws_messages(self) -> None:
        """Read the WebSocket until it closes, dispatching each text message.

        If this worker is cancelled, every pending request is failed with a
        ``ServerError`` before the cancellation propagates.
        """
        assert self.ws_connect is not None

        ws_msg: http_websocket.WSMessage

        # The cancellation handler must wrap the read loop itself: cancellation
        # is delivered while awaiting the next message, not by ensure_future().
        try:
            async for ws_msg in self.ws_connect:
                if ws_msg.type != http_websocket.WSMsgType.TEXT:
                    continue

                # Each message is processed in its own task, so the read loop
                # keeps consuming messages while a handler is being awaited.
                task = asyncio.ensure_future(self._handle_single_ws_message(ws_msg))
                task.add_done_callback(self._log_ws_message_error)
        except asyncio.CancelledError as e:
            error = errors.ServerError(utils.get_exc_message(e)).with_traceback()
            self._notify_all_about_error(error)
            raise

    @staticmethod
    def _log_ws_message_error(task: asyncio.Future) -> None:
        """Log an exception raised by a per-message processing task.

        Retrieving the exception also stops asyncio from reporting
        "Task exception was never retrieved" for the fire-and-forget task.
        """
        if task.cancelled():
            return

        exc = task.exception()

        if exc is not None:
            logger.warning('Can not process WS message.', exc_info=exc)

    async def _handle_single_ws_message(self, ws_msg: http_websocket.WSMessage) -> None:
        """Decode one text message and route it as a single item or a batch."""
        if ws_msg.type != http_websocket.WSMsgType.TEXT:
            return

        try:
            json_response = json.loads(ws_msg.data)
        except Exception:
            logger.warning('Can not parse json.', exc_info=True)
            return

        if not json_response:
            return

        if isinstance(json_response, dict):
            await self._handle_single_json_response(json_response, ws_msg=ws_msg)
            return

        if isinstance(json_response, list):
            await self._handle_json_responses(json_response, ws_msg=ws_msg)
            return

        logger.warning('Couldn\'t process the response.', extra={
            'json_response': json_response,
        })

    async def _handle_single_json_response(self, json_response: dict, *, ws_msg: web_ws.WSMessage) -> None:
        """Route one decoded JSON object.

        Routing rules:

        * Objects with a ``method`` go to the request handler.
        * Objects whose ``id`` is pending resolve that request.
        * Anything else goes to the unprocessed-response handler.
        """
        if 'method' in json_response:
            if self._json_request_handler:
                await self._json_request_handler(
                    ws_connect=self.ws_connect,
                    ws_msg=ws_msg,
                    json_request=json_response,
                )
        elif 'id' in json_response and json_response['id'] in self._pending:
            self._notify_about_result(json_response['id'], json_response)
        elif self._unprocessed_json_response_handler:
            # Called synchronously (not awaited), as before.
            self._unprocessed_json_response_handler(
                ws_connect=self.ws_connect,
                ws_msg=ws_msg,
                json_response=json_response,
            )

    async def _handle_json_responses(self, json_responses: list, *, ws_msg: web_ws.WSMessage) -> None:
        """Route a decoded JSON batch, mirroring :meth:`_handle_single_json_response`.

        The first item decides whether the batch is treated as incoming
        requests. Otherwise, ids that match pending requests resolve them, and
        a batch matching no pending request goes to the unprocessed-response
        handler.
        """
        if not json_responses:
            return

        if isinstance(json_responses[0], dict) and 'method' in json_responses[0]:
            if self._json_request_handler:
                # Pass the decoded batch, as the single-request path does.
                await self._json_request_handler(
                    ws_connect=self.ws_connect,
                    ws_msg=ws_msg,
                    json_request=json_responses,
                )
            return

        # Only ids this client is actually waiting for count as "processed",
        # consistent with the single-response path.
        response_ids = [
            response_id
            for response_id in self._get_ids_from_json(json_responses)
            if response_id in self._pending
        ]

        if response_ids:
            self._notify_about_results(response_ids, json_responses)
        elif self._unprocessed_json_response_handler:
            self._unprocessed_json_response_handler(
                ws_connect=self.ws_connect,
                ws_msg=ws_msg,
                json_response=json_responses,
            )

    def _notify_all_about_error(self, error: errors.JsonRpcError) -> None:
        """Fail every pending request with ``error`` and clear the registry."""
        for future in self._pending.values():
            # One batch future can appear under several ids, and a timed-out
            # request's future may already be done. Setting twice would raise
            # InvalidStateError.
            if not future.done():
                future.set_exception(error)

        self.clear_pending()

    def _notify_about_result(self, response_id: typing.Any, json_response: dict) -> None:
        """Resolve the request registered under ``response_id``, if still waiting."""
        future = self._pending.pop(response_id, None)

        if future is not None and not future.done():
            future.set_result(json_response)

    def _notify_about_results(self, response_ids: list, json_response: list) -> None:
        """Resolve every request referenced by ``response_ids`` with the whole batch.

        All ids of one batch request share a future, which is resolved once
        (later ids find it already done). Ids belonging to other futures also
        resolve those futures, instead of popping them and leaving their
        callers waiting.
        """
        for response_id in response_ids:
            future = self._pending.pop(response_id, None)

            if future is not None and not future.done():
                future.set_result(json_response)
227