Test Failed
Push — master ( c58336...8dd0a6 )
by Michael
03:57
created

aiohttp_rpc.client.BaseJsonRpcClient.direct_batch()   A

Complexity

Conditions 1

Size

Total Lines 7
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
| Metric | Value |
| ------ | ----- |
| cc     | 1     |
| eloc   | 6     |
| nop    | 2     |
| dl     | 0     |
| loc    | 7     |
| rs     | 10    |
| c      | 0     |
| b      | 0     |
| f      | 0     |
1
import abc
2
import asyncio
3
import json
4
import logging
5
import types
6
import typing
7
import uuid
8
from dataclasses import dataclass
9
from functools import partial
10
11
import aiohttp
12
13
from . import constants, errors, utils
14
from .protocol import JsonRpcRequest, JsonRpcResponse
15
16
17
# Public API of this module.
__all__ = ('BaseJsonRpcClient', 'JsonRpcClient', 'WsJsonRpcClient', 'UnlinkedResults')

# Module-level logger, used by the WebSocket client to report message-handling failures.
logger = logging.getLogger(__name__)
25
26
27
@dataclass
class UnlinkedResults:
    """Batch results that could not be matched to a request by ``id``.

    ``data`` collects, in arrival order, the result (or error) of every batch
    response whose ``id`` was missing.
    """

    data: list

    def compile(self) -> typing.Any:
        """Collapse the collected values into the simplest representation.

        Returns ``None`` when nothing was collected, the sole value when exactly
        one was collected, and this container itself otherwise.
        """
        if self.data:
            return self.data[0] if len(self.data) == 1 else self

        return None
39
40
41
class BaseJsonRpcClient(abc.ABC):
    """Transport-independent JSON-RPC client.

    Subclasses implement ``connect``, ``disconnect`` and ``send_json``. This base
    class builds request objects, parses replies into ``JsonRpcResponse`` using
    ``error_map`` (error code -> error class) and offers call/notify/batch helpers.
    Unknown (non-dunder) attribute access is turned into a remote call:
    ``await client.method_name(1, 2)`` is ``await client.call('method_name', 1, 2)``.
    """

    error_map: typing.Dict[int, errors.JsonRpcError] = {
        error.code: error
        for error in errors.DEFAULT_KNOWN_ERRORS
    }
    json_serialize: typing.Callable = utils.json_serialize

    async def __aenter__(self) -> 'BaseJsonRpcClient':
        await self.connect()
        return self

    async def __aexit__(self,
                        exc_type: typing.Optional[typing.Type[BaseException]],
                        exc_value: typing.Optional[BaseException],
                        traceback: typing.Optional[types.TracebackType]) -> None:
        await self.disconnect()

    @abc.abstractmethod
    async def connect(self) -> None:
        """Open the underlying transport."""
        pass

    @abc.abstractmethod
    async def disconnect(self) -> None:
        """Close the underlying transport."""
        pass

    async def call(self, method: str, *args, **kwargs) -> typing.Any:
        """Call ``method`` remotely and return its result.

        A fresh UUID4 string is used as the request id.

        Raises:
            The error carried by the response, if it has one.
        """
        request = JsonRpcRequest(msg_id=str(uuid.uuid4()), method=method, args=args, kwargs=kwargs)
        response = await self.direct_call(request)

        if response.error:
            raise response.error

        return response.result

    async def notify(self, method: str, *args, **kwargs) -> None:
        """Send ``method`` as a request without an explicit id and do not wait for a reply."""
        request = JsonRpcRequest(method=method, args=args, kwargs=kwargs)
        await self.send_json(request.to_dict(), without_response=True)

    async def batch(self, methods: typing.Iterable[typing.Union[str, list, tuple]]) -> typing.Any:
        """Send several calls as one JSON-RPC batch.

        Each entry of ``methods`` is a method name, ``(method,)``,
        ``(method, params)`` or ``(method, args, kwargs)``.

        Returns:
            A list with one value per entry, in the same order. Each value is the
            call's result, or its error object (errors are returned, not raised).
            Replies that carry no id cannot be matched to a request. They are
            collected via ``UnlinkedResults`` and used for every request whose
            own reply is missing.
        """
        requests = [self._parse_batch_method(method) for method in methods]
        responses = await self.direct_batch(requests)
        unlinked_results = UnlinkedResults(data=[])
        responses_map = {}

        for response in responses:
            if response.msg_id is None or response.msg_id is constants.NOTHING:
                # No usable id: this reply cannot be tied to a specific request.
                unlinked_results.data.append(response.error or response.result)
                continue

            responses_map[response.msg_id] = response.error or response.result

        unlinked = unlinked_results.compile()
        result = []

        for request in requests:
            if request.msg_id is None or request.msg_id is constants.NOTHING:
                result.append(unlinked)
                continue

            result.append(responses_map.get(request.msg_id, unlinked))

        return result

    async def batch_notify(self, methods: typing.Iterable[typing.Union[str, list, tuple]]) -> None:
        """Send several notifications as one batch without waiting for a reply.

        Entries have the same forms as for ``batch``.
        """
        requests = [self._parse_batch_method(method, is_notification=True) for method in methods]
        data = [request.to_dict() for request in requests]
        await self.send_json(data, without_response=True)

    async def direct_call(self, request: JsonRpcRequest) -> JsonRpcResponse:
        """Send a prepared request and return the parsed response (errors are not raised)."""
        json_response, context = await self.send_json(request.to_dict())
        response = JsonRpcResponse.from_dict(
            json_response,
            error_map=self.error_map,
            context=context,
        )
        return response

    async def direct_batch(self, requests: typing.List[JsonRpcRequest]) -> typing.List[JsonRpcResponse]:
        """Send prepared requests as one batch and return the parsed responses.

        Returns:
            The parsed responses in the order the server sent them. This is an
            empty list when there was nothing to send or the server returned no
            body (e.g. a batch made only of notifications).
        """
        if not requests:
            # JSON-RPC 2.0 treats an empty batch array as an invalid request,
            # so there is nothing worth sending.
            return []

        data = [request.to_dict() for request in requests]
        json_response, context = await self.send_json(data)

        if json_response is None:
            # The spec says the server returns nothing at all when no response
            # objects are due, instead of an empty array.
            return []

        if isinstance(json_response, dict):
            # When the batch as a whole is rejected, the server replies with a
            # single Response object instead of an array. Iterating that dict
            # would walk its keys.
            json_response = [json_response]

        return [
            JsonRpcResponse.from_dict(item, error_map=self.error_map, context=context)
            for item in json_response
        ]

    @abc.abstractmethod
    async def send_json(self,
                        data: typing.Any, *,
                        without_response: bool = False) -> typing.Tuple[typing.Any, typing.Optional[dict]]:
        """Transmit JSON-serializable ``data``.

        Returns:
            ``(decoded_reply, context)``. ``context`` is transport-specific
            extra information passed through to ``JsonRpcResponse.from_dict``.
        """
        pass

    def __getattr__(self, method: str) -> typing.Callable:
        # Only invoked for attributes not found normally. Dunder names are
        # protocol lookups (e.g. ``copy.deepcopy`` probing ``__deepcopy__``),
        # not RPC methods, so they must raise AttributeError.
        if method.startswith('__') and method.endswith('__'):
            raise AttributeError(method)

        return partial(self.call, method)

    @staticmethod
    def _parse_batch_method(batch_method: typing.Union[str, list, tuple], *,
                            is_notification: bool = False) -> JsonRpcRequest:
        """Build a request from a ``batch``/``batch_notify`` entry.

        Notifications get ``constants.NOTHING`` as id; other requests get a UUID4 string.

        Raises:
            errors.InvalidParams: if a sequence entry is empty or longer than 3.
        """
        msg_id = constants.NOTHING if is_notification else str(uuid.uuid4())

        if isinstance(batch_method, str):
            return JsonRpcRequest(msg_id=msg_id, method=batch_method)

        if len(batch_method) == 1:
            return JsonRpcRequest(msg_id=msg_id, method=batch_method[0])

        if len(batch_method) == 2:
            return JsonRpcRequest(msg_id=msg_id, method=batch_method[0], params=batch_method[1])

        if len(batch_method) == 3:
            return JsonRpcRequest(msg_id=msg_id, method=batch_method[0], args=batch_method[1], kwargs=batch_method[2])

        raise errors.InvalidParams('Use string or list (length less than or equal to 3).')
154
155
156
class JsonRpcClient(BaseJsonRpcClient):
    """JSON-RPC client that POSTs each message to ``url`` with aiohttp.

    A ``session`` passed to the constructor belongs to the caller and is never
    closed here. Otherwise ``connect`` creates one and ``disconnect`` closes it.
    Extra keyword arguments are forwarded to every ``session.post`` call.
    """

    url: str
    session: typing.Optional[aiohttp.ClientSession]
    request_kwargs: dict
    _is_outer_session: bool

    def __init__(self,
                 url: str, *,
                 session: typing.Optional[aiohttp.ClientSession] = None,
                 **request_kwargs) -> None:
        self.url = url
        self.session = session
        self.request_kwargs = request_kwargs
        self._is_outer_session = session is not None

    async def connect(self) -> None:
        """Create an own session if none is set yet."""
        if not self.session:
            self.session = aiohttp.ClientSession(json_serialize=self.json_serialize)

    async def disconnect(self) -> None:
        """Close the session this client created (a caller-supplied one is left open)."""
        if self._is_outer_session or self.session is None:
            return

        await self.session.close()
        # Forget the closed session so a later ``connect`` creates a fresh one
        # instead of reusing a closed one.
        self.session = None

    async def send_json(self,
                        data: typing.Any, *,
                        without_response: bool = False) -> typing.Tuple[typing.Any, typing.Optional[dict]]:
        """POST ``data`` as JSON.

        Returns:
            ``(decoded_body, {'http_response': response})``, or ``(None, None)``
            when ``without_response`` is true.

        Raises:
            errors.ParseError: if the reply is not served as JSON.
        """
        # The context manager releases the response and its connection on exit.
        async with self.session.post(self.url, json=data, **self.request_kwargs) as http_response:
            if without_response:
                # Notifications get no JSON-RPC reply; the body may be empty or
                # non-JSON, so it is drained (allowing keep-alive reuse) but not parsed.
                await http_response.read()
                return None, None

            try:
                json_response = await http_response.json()
            except aiohttp.ContentTypeError as e:
                raise errors.ParseError(utils.get_exc_message(e)) from e

        return json_response, {'http_response': http_response}
190
191
192
class WsJsonRpcClient(BaseJsonRpcClient):
    """JSON-RPC client over a single WebSocket connection.

    ``send_json`` registers one future per outgoing request id in ``_pending``.
    A background task started by ``connect`` (``_handle_ws_messages``) reads
    incoming messages and resolves the future whose id matches. ``timeout`` is
    the number of seconds to wait for a reply (``None`` waits indefinitely).
    Extra keyword arguments are forwarded to ``session.ws_connect``.
    """

    ws_connect = None
    notify_about_result: typing.Optional[typing.Callable] = None
    timeout: typing.Optional[int]
    ws_connect_kwargs: dict
    _pending: typing.Dict[typing.Any, asyncio.Future]
    _message_worker: typing.Optional[asyncio.Future] = None

    def __init__(self,
                 url: str, *,
                 session: typing.Optional[aiohttp.ClientSession] = None,
                 timeout: typing.Optional[int] = 5,
                 **ws_connect_kwargs) -> None:
        self.url = url
        self.session = session
        self._is_outer_session = session is not None
        self._pending = {}
        self.timeout = timeout
        self.ws_connect_kwargs = ws_connect_kwargs

    async def connect(self) -> None:
        """Open the WebSocket and start the background reader task.

        If opening the socket fails, the client is disconnected and the error re-raised.
        """
        if not self.session:
            self.session = aiohttp.ClientSession(json_serialize=self.json_serialize)

        try:
            self.ws_connect = await self.session.ws_connect(self.url, **self.ws_connect_kwargs)
        except Exception:
            await self.disconnect()
            raise

        self._message_worker = asyncio.ensure_future(self._handle_ws_messages())

    async def disconnect(self) -> None:
        """Close the socket and the own session, then wait for the reader task to finish."""
        if self.ws_connect:
            await self.ws_connect.close()

        if not self._is_outer_session and self.session is not None:
            await self.session.close()
            # Allow a later ``connect`` to create a fresh session.
            self.session = None

        if self._message_worker:
            await self._message_worker
            self._message_worker = None

    async def send_json(self,
                        data: typing.Any, *,
                        without_response: bool = False) -> typing.Tuple[typing.Any, typing.Optional[dict]]:
        """Send ``data`` and, if it carries ids, wait for the matching reply.

        Returns:
            ``(decoded_reply, None)``, or ``(None, None)`` when no reply is expected.

        Raises:
            asyncio.TimeoutError: if ``timeout`` elapses first.
        """
        if without_response:
            await self.ws_connect.send_str(self.json_serialize(data))
            return None, None

        msg_ids = self._get_msg_ids(data)

        if not msg_ids:
            # Nothing in ``data`` can be answered, so there is nothing to wait for.
            await self.ws_connect.send_str(self.json_serialize(data))
            return None, None

        # All ids of a batch share one future, resolved by the batch reply.
        future = asyncio.Future()

        for msg_id in msg_ids:
            self._pending[msg_id] = future

        try:
            await self.ws_connect.send_str(self.json_serialize(data))

            if self.timeout is None:
                result = await future
            else:
                result = await asyncio.wait_for(future, timeout=self.timeout)
        finally:
            # On timeout, cancellation or send failure the entries would otherwise
            # leak, and a late reply would hit an already-cancelled future.
            for msg_id in msg_ids:
                if self._pending.get(msg_id) is future:
                    del self._pending[msg_id]

        return result, None

    @staticmethod
    def _get_msg_ids(data: typing.Any) -> tuple:
        """Return the ids of the request(s) in ``data`` that carry an ``"id"`` key."""
        # Entries without an "id" key (notifications) get no reply to wait for.
        if isinstance(data, dict):
            return (data['id'],) if 'id' in data else ()

        if isinstance(data, list):
            return tuple(item['id'] for item in data if isinstance(item, dict) and 'id' in item)

        return ()

    async def _handle_ws_messages(self) -> None:
        """Read messages until the socket closes, dispatching each to the waiting future."""
        while not self.ws_connect.closed:
            try:
                ws_msg = await self.ws_connect.receive()
                self._handle_ws_message(ws_msg)
            except asyncio.CancelledError as e:
                error = errors.ServerError(utils.get_exc_message(e)).with_traceback()
                self._notify_all_about_error(error)
                raise
            except Exception as e:
                logger.exception(e)

        # No further replies can arrive on a closed socket; fail the remaining
        # waiters instead of leaving them hanging (forever when timeout is None).
        self._notify_all_about_error(errors.ServerError('Connection is closed.'))

    def _handle_ws_message(self, ws_msg: aiohttp.WSMessage) -> None:
        """Route one text message to the future(s) registered for its id(s); other types are ignored."""
        if ws_msg.type != aiohttp.WSMsgType.text:
            return

        json_response = json.loads(ws_msg.data)

        if isinstance(json_response, dict) and 'id' in json_response:
            self._notify_about_result(json_response['id'], json_response)
            return

        if isinstance(json_response, list):
            self._notify_about_results(
                [
                    item['id']
                    for item in json_response
                    if isinstance(item, dict) and 'id' in item
                ],
                json_response,
            )

    def _notify_all_about_error(self, error: errors.JsonRpcError) -> None:
        """Fail every pending request with ``error`` and clear the registry."""
        pending, self._pending = self._pending, {}

        # Ids of one batch share a future; set it only once, and skip futures
        # that are already finished (e.g. cancelled by a timeout).
        for future in set(pending.values()):
            if not future.done():
                future.set_exception(error)

    def _notify_about_result(self, msg_id: typing.Any, json_response: dict) -> None:
        """Resolve the future waiting for ``msg_id``, if any, with ``json_response``."""
        future = self._pending.pop(msg_id, None)

        if future is not None and not future.done():
            future.set_result(json_response)

    def _notify_about_results(self, msg_ids: list, json_response: list) -> None:
        """Resolve the (shared) future of a batch reply and drop all of its ids."""
        is_processed = False

        for msg_id in msg_ids:
            future = self._pending.pop(msg_id, None)

            if future is not None and not future.done() and not is_processed:
                future.set_result(json_response)
                is_processed = True