Passed
Push — master ( 239fba...767d9a )
by Michael
03:05
created

WsJsonRpcClient._handle_ws_message()   A

Complexity

Conditions 5

Size

Total Lines 18
Code Lines 14

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 14
dl 0
loc 18
rs 9.2333
c 0
b 0
f 0
cc 5
nop 2
1
import asyncio
2
import json
3
import logging
4
import typing
5
6
import aiohttp
7
8
from .base import BaseJsonRpcClient
9
from .. import errors, utils
10
11
12
# Public API of this module.
__all__ = ('WsJsonRpcClient',)

# Module-level logger; used by the client to report errors raised while
# handling incoming WebSocket messages.
logger = logging.getLogger(__name__)
17
18
19
class WsJsonRpcClient(BaseJsonRpcClient):
    """JSON-RPC client that talks to a server over a single WebSocket connection.

    Requests are sent with :meth:`send_json`. A background task started by
    :meth:`connect` reads incoming text frames and resolves the future that
    was registered for each request ``id``.
    """

    # Active aiohttp WebSocket response object; set by :meth:`connect`.
    ws_connect = None
    notify_about_result: typing.Optional[typing.Callable] = None
    timeout: typing.Optional[int]
    ws_connect_kwargs: dict
    # Maps request id -> future awaiting its response. All ids of one batch
    # request share the same future.
    _pending: typing.Dict[typing.Any, asyncio.Future]
    # Background task running :meth:`_handle_ws_messages`.
    _message_worker: typing.Optional[asyncio.Future] = None

    def __init__(self,
                 url: str, *,
                 session: typing.Optional[aiohttp.ClientSession] = None,
                 timeout: typing.Optional[int] = 5,
                 **ws_connect_kwargs) -> None:
        """
        :param url: WebSocket URL of the JSON-RPC server.
        :param session: optional externally owned session; when given it is
            never closed by :meth:`disconnect`.
        :param timeout: seconds :meth:`send_json` waits for a response;
            ``None`` waits indefinitely.
        :param ws_connect_kwargs: extra keyword arguments forwarded to
            ``session.ws_connect()``.
        """
        self.url = url
        self.session = session
        self._is_outer_session = session is not None
        self._pending = {}
        self.timeout = timeout
        self.ws_connect_kwargs = ws_connect_kwargs

    async def connect(self) -> None:
        """Open the WebSocket connection and start the background reader.

        Creates a private ``aiohttp.ClientSession`` when none is set. If the
        connection attempt fails, resources are released via
        :meth:`disconnect` and the original exception is re-raised.
        """
        if not self.session:
            self.session = aiohttp.ClientSession(json_serialize=self.json_serialize)

        try:
            self.ws_connect = await self.session.ws_connect(self.url, **self.ws_connect_kwargs)
        except Exception:
            await self.disconnect()
            raise

        self._message_worker = asyncio.ensure_future(self._handle_ws_messages())

    async def disconnect(self) -> None:
        """Close the WebSocket, the private session (if any) and wait for the reader.

        Safe to call before :meth:`connect` and more than once. A private
        session is dropped after closing so a later :meth:`connect` creates a
        fresh one instead of reusing a closed session.
        """
        try:
            if self.ws_connect:
                await self.ws_connect.close()
        finally:
            # Always release a session we own, even if closing the socket failed.
            if not self._is_outer_session and self.session is not None:
                await self.session.close()
                self.session = None

        worker, self._message_worker = self._message_worker, None
        if worker:
            await worker

    async def send_json(self,
                        data: typing.Any, *,
                        without_response: bool = False) -> typing.Tuple[typing.Any, typing.Optional[dict]]:
        """Send a request (or batch) and wait for its response.

        :param data: JSON-serializable request object or list of request objects.
        :param without_response: when true, send and return immediately.
        :returns: ``(response, None)`` where ``response`` is the decoded JSON
            reply, or ``(None, None)`` when no response is awaited (flag set,
            or the request carries no non-null ``id``).
        :raises asyncio.TimeoutError: if ``self.timeout`` elapses first.
        """
        if without_response:
            await self.ws_connect.send_str(self.json_serialize(data))
            return None, None

        msg_ids = self._get_msg_ids_from_json(data)
        future = asyncio.get_event_loop().create_future()

        # Register before sending so a very fast reply cannot be missed.
        for msg_id in msg_ids:
            self._pending[msg_id] = future

        try:
            await self.ws_connect.send_str(self.json_serialize(data))

            if not msg_ids:
                return None, None

            if self.timeout is not None:
                result = await asyncio.wait_for(future, timeout=self.timeout)
            else:
                result = await future
        finally:
            # Drop our registrations on success, timeout or send failure so
            # stale entries do not accumulate. Identity check: leave entries
            # that were re-registered by another request untouched.
            for msg_id in msg_ids:
                if self._pending.get(msg_id) is future:
                    del self._pending[msg_id]

        return result, None

    def clear_pending(self) -> None:
        """Forget every registered request without resolving its future."""
        self._pending = {}

    @staticmethod
    def _get_msg_ids_from_json(data: typing.Any) -> list:
        """Return the non-null ``id`` values of a request or batch of requests.

        Requests without an ``id`` (or with ``id`` ``None``) and non-object
        batch entries are skipped, so an empty list means no reply is awaited.
        """
        if isinstance(data, dict):
            msg_id = data.get('id')
            return [] if msg_id is None else [msg_id]

        if isinstance(data, list):
            return [
                item['id']
                for item in data
                if isinstance(item, dict) and item.get('id') is not None
            ]

        return []

    async def _handle_ws_messages(self) -> None:
        """Read frames until the socket closes, dispatching each one.

        Errors from individual frames are logged and reading continues. On
        cancellation every pending request is failed with ``ServerError``
        before the cancellation propagates. When the socket closes, requests
        still pending are failed too so their callers do not wait forever.
        """
        while not self.ws_connect.closed:
            try:
                ws_msg = await self.ws_connect.receive()
                self._handle_ws_message(ws_msg)
            except asyncio.CancelledError as e:
                # with_traceback() requires the traceback argument.
                error = errors.ServerError(utils.get_exc_message(e)).with_traceback(e.__traceback__)
                self._notify_all_about_error(error)
                raise
            except Exception as e:
                logger.exception(e)

        self._notify_all_about_error(errors.ServerError('WebSocket connection is closed'))

    def _handle_ws_message(self, ws_msg: aiohttp.WSMessage) -> None:
        """Dispatch one incoming WebSocket frame to the waiting request(s).

        Non-text frames are ignored. A JSON object carrying an ``id`` resolves
        a single request; a JSON array resolves the batch whose ids appear in
        it. Malformed JSON raises ``json.JSONDecodeError``.
        """
        if ws_msg.type != aiohttp.WSMsgType.TEXT:
            return

        json_response = json.loads(ws_msg.data)

        if isinstance(json_response, dict) and 'id' in json_response:
            self._notify_about_result(json_response['id'], json_response)
        elif isinstance(json_response, list):
            msg_ids = [
                item['id']
                for item in json_response
                if isinstance(item, dict) and 'id' in item
            ]
            self._notify_about_results(msg_ids, json_response)

    def _notify_all_about_error(self, error: errors.JsonRpcError) -> None:
        """Fail every pending request with ``error`` and clear the registry."""
        # Detach first so the registry is cleared even if a callback misbehaves.
        pending, self._pending = self._pending, {}

        # A batch shares one future across several ids: resolve each future
        # once, and skip futures already finished (e.g. timed out/cancelled).
        for future in set(pending.values()):
            if not future.done():
                future.set_exception(error)

    def _notify_about_result(self, msg_id: typing.Any, json_response: dict) -> None:
        """Resolve the request registered under ``msg_id`` with ``json_response``."""
        future = self._pending.pop(msg_id, None)

        # The waiter may already be gone (timeout cancels the future).
        if future is not None and not future.done():
            future.set_result(json_response)

    def _notify_about_results(self, msg_ids: list, json_response: list) -> None:
        """Resolve the batch request(s) registered under ``msg_ids``.

        Every id is removed from the registry; each distinct future receives
        the whole ``json_response`` list exactly once.
        """
        for msg_id in msg_ids:
            future = self._pending.pop(msg_id, None)

            # done() is true after the first id of a shared batch future.
            if future is not None and not future.done():
                future.set_result(json_response)
159