Passed
Pull Request — master (#13)
by Michael
03:50
created

WsJsonRpcClient.__init__()   A

Complexity

Conditions 1

Size

Total Lines 26
Code Lines 22

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 22
dl 0
loc 26
rs 9.352
c 0
b 0
f 0
cc 1
nop 10

How to fix: Many Parameters

Many Parameters

Methods with many parameters are not only hard to understand; their parameters also tend to become inconsistent when you need more or different data.

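There are several approaches to avoiding long parameter lists:

- Replace a parameter with a method call, when the callee can obtain the value itself.
- Introduce a parameter object that groups values which always travel together.
- Preserve the whole object: pass the object that owns several values instead of passing each value separately.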

1
import asyncio
2
import json
3
import logging
4
import typing
5
6
from aiohttp import ClientSession, http_websocket, web_ws
7
8
from .base import BaseJsonRpcClient
9
from .. import errors, typedefs, utils
10
11
12
# Public API of this module: only the WebSocket client class is exported.
__all__ = ('WsJsonRpcClient',)


# Module-level logger, named after the module.
logger = logging.getLogger(__name__)
17
18
19
class WsJsonRpcClient(BaseJsonRpcClient):
    """JSON-RPC client that exchanges messages over a single WebSocket connection.

    Outgoing requests are written with ``send_str``. Every non-null request ``id`` is
    mapped to a future in ``_pending``. A background worker reads incoming messages
    and resolves that future when a response with a matching ``id`` arrives.
    Incoming objects that contain ``method`` are passed to ``json_request_handler``.
    Responses that match no pending request are passed to
    ``unprocessed_json_response_handler``.

    A session or WebSocket supplied by the caller is "outer" and is never closed by
    this client.
    """

    url: typing.Optional[str]
    ws_connect: typing.Optional[typedefs.WSConnectType]
    ws_connect_kwargs: dict
    timeout: typing.Optional[int]
    _connection_check_interval: typing.Optional[int]
    _pending: typing.Dict[typing.Any, asyncio.Future]
    _message_worker: typing.Optional[asyncio.Future] = None
    _check_worker: typing.Optional[asyncio.Future] = None
    _session_is_outer: bool
    _ws_connect_is_outer: bool
    _json_request_handler: typing.Optional[typing.Callable] = None
    _unprocessed_json_response_handler: typing.Optional[typing.Callable] = None
    # Strong references to in-flight message tasks, so they are not garbage-collected mid-run.
    _background_tasks: typing.Set[asyncio.Task]

    def __init__(self,
                 url: typing.Optional[str] = None, *,
                 session: typing.Optional[ClientSession] = None,
                 ws_connect: typing.Optional[typedefs.WSConnectType] = None,
                 timeout: typing.Optional[int] = 5,
                 connection_check_interval: typing.Optional[int] = 1,
                 json_request_handler: typing.Optional[typing.Callable] = None,
                 unprocessed_json_response_handler: typing.Optional[typing.Callable] = None,
                 **ws_connect_kwargs) -> None:
        """Store the connection settings. No I/O happens until ``connect()``.

        :param url: WebSocket URL. ``connect()`` uses it when no ``ws_connect`` is given.
        :param session: existing ``ClientSession`` used to open the WebSocket.
            It is treated as outer and is not closed by ``disconnect()``.
        :param ws_connect: an already-open WebSocket. It is treated as outer and is
            not closed by ``disconnect()``.
        :param timeout: seconds ``send_json`` waits for a response; ``None`` waits forever.
        :param connection_check_interval: polling period, in seconds, of the
            closed-connection watchdog; ``None`` disables the watchdog.
        :param json_request_handler: coroutine function awaited with the keyword
            arguments ``ws_connect``, ``ws_msg`` and ``json_request`` for incoming
            messages that contain ``method``.
        :param unprocessed_json_response_handler: function or coroutine function
            called with the keyword arguments ``ws_connect``, ``ws_msg`` and
            ``json_response`` for responses that match no pending request.
        :param ws_connect_kwargs: extra keyword arguments for ``ClientSession.ws_connect``.
        """
        # At least one source for the WebSocket is required.
        # NOTE: a session without a url passes this check, but connect() then fails its url assertion.
        assert session is not None or url is not None or ws_connect is not None

        self.url = url
        self.timeout = timeout
        self._connection_check_interval = connection_check_interval

        self.session = session
        self._session_is_outer = session is not None  # We don't close an outer session.

        self.ws_connect = ws_connect
        self.ws_connect_kwargs = ws_connect_kwargs
        self._ws_connect_is_outer = ws_connect is not None  # We don't close an outer WS connection.

        self._pending = {}
        self._json_request_handler = json_request_handler
        self._unprocessed_json_response_handler = unprocessed_json_response_handler
        self._background_tasks = set()

    async def connect(self) -> None:
        """Open the WebSocket (unless an outer one was given) and start the background workers.

        An inner session or WebSocket that a previous ``disconnect()`` has closed is
        replaced. This includes the ``disconnect()`` performed here when connecting
        fails, so ``connect()`` can be retried.
        """
        if not self._ws_connect_is_outer and (self.ws_connect is None or self.ws_connect.closed):
            if self.session is None or (not self._session_is_outer and self.session.closed):
                self.session = ClientSession(json_serialize=self.json_serialize)

            assert self.url is not None and self.session is not None

            try:
                self.ws_connect = await self.session.ws_connect(self.url, **self.ws_connect_kwargs)
            except Exception:
                # Release the inner session before propagating the error.
                await self.disconnect()
                raise

        self._message_worker = asyncio.create_task(self._handle_ws_messages())

        if self._connection_check_interval is not None:
            self._check_worker = asyncio.create_task(self._check_ws_connection())

    async def disconnect(self) -> None:
        """Close inner resources and wait for the background workers to finish.

        An outer (caller-supplied) session and WebSocket are left open. An open outer
        WebSocket would keep both workers running indefinitely, so in that case the
        workers are cancelled instead of awaited.
        """
        if self.ws_connect is not None and not self._ws_connect_is_outer:
            await self.ws_connect.close()

        if self.session is not None and not self._session_is_outer:
            await self.session.close()

        workers = [worker for worker in (self._message_worker, self._check_worker) if worker is not None]

        if self._ws_connect_is_outer:
            for worker in workers:
                worker.cancel()

            if workers:
                # asyncio.wait does not re-raise the workers' CancelledError,
                # and it still propagates cancellation of this task.
                await asyncio.wait(workers)

            return

        # Closing the inner WebSocket ends the read loop; the watchdog ends once it sees `closed`.
        for worker in workers:
            await worker

    async def send_json(self,
                        data: typing.Any, *,
                        without_response: bool = False,
                        **kwargs) -> typing.Tuple[typing.Any, typing.Optional[dict]]:
        """Serialize and send ``data``, and optionally wait for the matching response.

        :param data: a single JSON-RPC object or a batch (sequence) of objects.
        :param without_response: if true, send and return ``(None, None)`` immediately.
        :param kwargs: forwarded to ``ws_connect.send_str``.
        :returns: ``(response, None)``, where ``response`` is the decoded message whose
            id(s) match those in ``data``. Returns ``(None, None)`` when nothing is
            awaited: ``without_response`` is set, or ``data`` carries no ids.
        :raises errors.ServerError: the connection was reset while sending. Every
            pending request is failed with the same error.
        :raises asyncio.TimeoutError: no response arrived within ``self.timeout`` seconds.
        """
        assert self.ws_connect is not None

        if without_response:
            await self._send_str(data, **kwargs)
            return None, None

        request_ids = self._get_ids_from_json(data)
        future: asyncio.Future = asyncio.get_running_loop().create_future()

        # Register before sending: the reader may process the response while send_str is awaited.
        for request_id in request_ids:
            self._pending[request_id] = future

        try:
            await self._send_str(data, **kwargs)

            if not request_ids:
                return None, None

            if self.timeout is None:
                result = await future
            else:
                result = await asyncio.wait_for(future, timeout=self.timeout)
        finally:
            # On timeout or cancellation our entries are still registered. Drop them so they
            # don't leak and a late response doesn't try to resolve a cancelled future.
            for request_id in request_ids:
                if self._pending.get(request_id) is future:
                    del self._pending[request_id]

        return result, None

    async def _send_str(self, data: typing.Any, **kwargs) -> None:
        """Serialize ``data`` and write it to the WebSocket.

        A ``ConnectionResetError`` is converted to ``errors.ServerError``. That error is
        also delivered to every pending request before being raised.
        """
        assert self.ws_connect is not None

        try:
            await self.ws_connect.send_str(self.json_serialize(data), **kwargs)
        except ConnectionResetError as e:
            error = errors.ServerError(utils.get_exc_message(e)).with_traceback()
            self._notify_all_about_error(error)
            raise error

    @staticmethod
    def _get_ids_from_json(data: typing.Any) -> typing.Tuple[typedefs.JsonRpcIdType, ...]:
        """Return the non-null ``id`` values of one JSON object or of the objects in a batch.

        Anything else, including empty input, yields ``()``.
        """
        if not data:
            return ()

        if isinstance(data, typing.Mapping) and data.get('id') is not None:
            return (
                data['id'],
            )

        if isinstance(data, typing.Sequence):
            return tuple(
                item['id']
                for item in data
                if isinstance(item, typing.Mapping) and item.get('id') is not None
            )

        return ()

    async def _handle_ws_messages(self) -> None:
        """Read messages until the connection ends, handling each text message in its own task.

        When reading stops, every pending request is failed with ``ServerError``
        ('Connection is closed'). If this worker is cancelled, or reading raises,
        pending requests get ``InternalError`` instead and the exception is re-raised.
        """
        assert self.ws_connect is not None

        ws_msg_types_to_stop = (
            http_websocket.WSMsgType.CLOSE,
            http_websocket.WSMsgType.CLOSING,
            http_websocket.WSMsgType.CLOSED,
            http_websocket.WSMsgType.ERROR,
        )

        ws_msg: http_websocket.WSMessage

        try:
            async for ws_msg in self.ws_connect:
                if ws_msg.type == http_websocket.WSMsgType.TEXT:
                    task = asyncio.create_task(self._handle_single_ws_message(ws_msg))
                    # To avoid a task disappearing mid execution:
                    self._background_tasks.add(task)
                    task.add_done_callback(self._on_background_task_done)
                elif ws_msg.type in ws_msg_types_to_stop:
                    break
        except (asyncio.CancelledError, Exception) as e:
            error = errors.InternalError(utils.get_exc_message(e)).with_traceback()
            self._notify_all_about_error(error)
            raise

        # Nothing more can arrive on this connection, so no pending request can be answered.
        self._notify_all_about_error(errors.ServerError('Connection is closed'))

    def _on_background_task_done(self, task: asyncio.Task) -> None:
        """Forget a finished message task and log its exception, if it raised one."""
        self._background_tasks.discard(task)

        # Retrieving the exception also stops asyncio reporting it as "never retrieved".
        if not task.cancelled() and task.exception() is not None:
            logger.warning('Can\'t process WS message.', exc_info=task.exception())

    async def _check_ws_connection(self) -> None:
        """Poll ``ws_connect.closed`` and fail all pending requests once it becomes true.

        The polling period is ``_connection_check_interval`` seconds. The loop exits
        after the first failure notification.
        """
        assert self.ws_connect is not None

        while True:
            if self.ws_connect.closed:
                error = errors.ServerError('Connection is closed')
                self._notify_all_about_error(error)
                break

            await asyncio.sleep(self._connection_check_interval)

    async def _handle_single_ws_message(self, ws_msg: http_websocket.WSMessage) -> None:
        """Decode one text message and dispatch it as a single object or as a batch."""
        if ws_msg.type != http_websocket.WSMsgType.TEXT:
            return

        try:
            json_response = json.loads(ws_msg.data)
        except Exception:
            logger.warning('Can\'t parse json.', exc_info=True)
            return

        if not json_response:
            return

        if isinstance(json_response, typing.Mapping):
            await self._handle_single_json_response(json_response, ws_msg=ws_msg)
            return

        if isinstance(json_response, typing.Sequence):
            await self._handle_json_responses(json_response, ws_msg=ws_msg)
            return

        logger.warning('Couldn\'t process the response.', extra={
            'json_response': json_response,
        })

    async def _handle_single_json_response(self, json_response: typing.Mapping, *, ws_msg: web_ws.WSMessage) -> None:
        """Route one decoded object.

        An object with ``method`` goes to the request handler. An object whose ``id``
        is pending resolves that request. Anything else goes to the unprocessed handler.
        """
        if 'method' in json_response:
            if self._json_request_handler is not None:
                await self._json_request_handler(
                    ws_connect=self.ws_connect,
                    ws_msg=ws_msg,
                    json_request=json_response,
                )
        elif 'id' in json_response and json_response['id'] in self._pending:
            self._notify_about_result(json_response['id'], json_response)
        else:
            await self._call_unprocessed_json_response_handler(ws_msg=ws_msg, json_response=json_response)

    async def _handle_json_responses(self, json_responses: typing.Sequence, *, ws_msg: web_ws.WSMessage) -> None:
        """Route a decoded batch.

        If the first item has ``method``, the whole batch goes to the request handler.
        Otherwise the batch resolves the request that owns its ids; a batch without
        ids goes to the unprocessed handler.
        """
        if json_responses and isinstance(json_responses[0], typing.Mapping) and 'method' in json_responses[0]:
            if self._json_request_handler is not None:
                # Same keyword arguments as the single-object path, so the handler receives the batch.
                await self._json_request_handler(
                    ws_connect=self.ws_connect,
                    ws_msg=ws_msg,
                    json_request=json_responses,
                )
            return

        response_ids = self._get_ids_from_json(json_responses)

        if response_ids:
            self._notify_about_results(response_ids, json_responses)
        else:
            await self._call_unprocessed_json_response_handler(ws_msg=ws_msg, json_response=json_responses)

    async def _call_unprocessed_json_response_handler(self, *,
                                                      ws_msg: web_ws.WSMessage,
                                                      json_response: typing.Any) -> None:
        """Pass a response that matched no pending request to the user's handler, if one is set.

        The handler may be a plain function or a coroutine function. A returned
        coroutine is awaited rather than silently dropped.
        """
        if self._unprocessed_json_response_handler is None:
            return

        result = self._unprocessed_json_response_handler(
            ws_connect=self.ws_connect,
            ws_msg=ws_msg,
            json_response=json_response,
        )

        if asyncio.iscoroutine(result):
            await result

    def _notify_all_about_error(self, error: Exception) -> None:
        """Fail every pending future with ``error`` and clear the pending map."""
        for future in self._pending.values():
            try:
                future.set_exception(error)
            except asyncio.InvalidStateError:
                # Already resolved or cancelled (a batch shares one future across several ids).
                pass

        self._pending.clear()

    def _notify_about_result(self, response_id: typedefs.JsonRpcIdType, json_response: typing.Mapping) -> None:
        """Resolve the future registered for ``response_id`` with ``json_response``, if it is still waiting."""
        future = self._pending.pop(response_id, None)

        # The future may already be done, e.g. cancelled by the timeout in send_json.
        if future is not None and not future.done():
            future.set_result(json_response)

    def _notify_about_results(self,
                              response_ids: typing.Sequence[typedefs.JsonRpcIdType],
                              json_response: typing.Sequence) -> None:
        """Remove every id in ``response_ids`` from the pending map.

        Only the first still-waiting future found is resolved, with the whole batch.
        """
        is_processed = False

        for response_id in response_ids:
            future = self._pending.pop(response_id, None)

            if future is None or is_processed or future.done():
                continue

            # We suppose that a batch result has the same ids that we sent.
            # And these ids have the same future.
            future.set_result(json_response)
            is_processed = True
279