Passed
Push — master ( 9a8b07...1de29a )
by Michael
03:47
created

aiohttp_rpc.client.ws   A

Complexity

Total Complexity 32

Size/Duplication

Total Lines 146
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 32
eloc 108
dl 0
loc 146
rs 9.84
c 0
b 0
f 0

9 Methods

Rating   Name   Duplication   Size   Complexity  
A WsJsonRpcClient._notify_about_result() 0 5 2
A WsJsonRpcClient._notify_all_about_error() 0 5 2
A WsJsonRpcClient.__init__() 0 11 1
A WsJsonRpcClient._handle_ws_messages() 0 11 4
A WsJsonRpcClient._handle_ws_message() 0 18 5
A WsJsonRpcClient.disconnect() 0 9 4
A WsJsonRpcClient._notify_about_results() 0 9 4
B WsJsonRpcClient.send_json() 0 31 7
A WsJsonRpcClient.connect() 0 11 3
1
import asyncio
2
import json
3
import logging
4
import typing
5
6
import aiohttp
7
8
from .base import BaseJsonRpcClient
9
from .. import errors, utils
10
11
12
# Names exported by ``from ... import *`` on this module.
__all__ = ('WsJsonRpcClient',)

logger = logging.getLogger(__name__)
17
18
19
class WsJsonRpcClient(BaseJsonRpcClient):
    """JSON-RPC client that talks to a server over one websocket connection.

    Requests are written with ``send_str``. A background task started by
    :meth:`connect` reads incoming text frames and resolves the future that is
    waiting for the matching request id (all ids of a batch share one future).
    """

    # The websocket returned by ``session.ws_connect``; None until connect().
    ws_connect = None
    # NOTE(review): not referenced inside this class; presumably used elsewhere.
    notify_about_result: typing.Optional[typing.Callable] = None
    # Seconds to wait for a response; None waits indefinitely.
    timeout: typing.Optional[int]
    # Extra keyword arguments forwarded verbatim to ``session.ws_connect``.
    ws_connect_kwargs: dict
    # Request id -> future awaiting its response.
    _pending: typing.Dict[typing.Any, asyncio.Future]
    # Background task running ``_handle_ws_messages``.
    _message_worker: typing.Optional[asyncio.Future] = None

    def __init__(self,
                 url: str, *,
                 session: typing.Optional[aiohttp.ClientSession] = None,
                 timeout: typing.Optional[int] = 5,
                 **ws_connect_kwargs) -> None:
        """Store connection settings; no network I/O happens here.

        :param url: websocket URL passed to ``session.ws_connect``.
        :param session: optional caller-owned session. When given, it is never
            closed by this client; otherwise :meth:`connect` creates one.
        :param timeout: seconds to wait for each response, or None for no limit.
        :param ws_connect_kwargs: forwarded to ``session.ws_connect``.
        """
        self.url = url
        self.session = session
        self._is_outer_session = session is not None
        self._pending = {}
        self.timeout = timeout
        self.ws_connect_kwargs = ws_connect_kwargs

    async def connect(self) -> None:
        """Open the websocket and start the background message reader.

        If opening the websocket fails, :meth:`disconnect` is called to release
        resources and the original exception is re-raised.
        """
        if not self.session:
            # ``json_serialize`` is presumably provided by BaseJsonRpcClient.
            self.session = aiohttp.ClientSession(json_serialize=self.json_serialize)

        try:
            self.ws_connect = await self.session.ws_connect(self.url, **self.ws_connect_kwargs)
        except Exception:
            await self.disconnect()
            raise

        self._message_worker = asyncio.ensure_future(self._handle_ws_messages())

    async def disconnect(self) -> None:
        """Close the websocket, wait for the reader, and close an owned session."""
        if self.ws_connect is not None:
            await self.ws_connect.close()

        # Closing the socket makes the reader loop exit; wait for it so it
        # finishes while the session is still alive.
        if self._message_worker is not None:
            await self._message_worker
            self._message_worker = None

        # Only close sessions this client created. Guard against None (e.g.
        # disconnect() before connect()) and drop the closed session so a
        # later connect() creates a fresh one instead of reusing a closed one.
        if not self._is_outer_session and self.session is not None:
            await self.session.close()
            self.session = None

    async def send_json(self,
                        data: typing.Any, *,
                        without_response: bool = False) -> typing.Tuple[typing.Any, typing.Optional[dict]]:
        """Send ``data`` as JSON and, if it carries request ids, await the reply.

        :param data: a single JSON-RPC message (dict) or a batch (list).
        :param without_response: send and return immediately without waiting.
        :returns: ``(response, None)``. ``response`` is the decoded reply: a dict
            for a single request, or the whole reply list for a batch. It is None
            when nothing was awaited (``without_response`` is true, or no message
            has an ``'id'``).
        :raises asyncio.TimeoutError: if ``timeout`` elapses first.
        """
        if without_response:
            await self.ws_connect.send_str(self.json_serialize(data))
            return None, None

        # Only messages that carry an 'id' get a reply; notifications do not,
        # so they must not be waited for (and must not raise KeyError).
        if isinstance(data, dict):
            msg_ids = (data['id'],) if 'id' in data else ()
        elif isinstance(data, list):
            msg_ids = tuple(
                item['id']
                for item in data
                if isinstance(item, dict) and 'id' in item
            )
        else:
            msg_ids = ()

        if not msg_ids:
            await self.ws_connect.send_str(self.json_serialize(data))
            return None, None

        future = asyncio.Future()

        # Every id of a batch maps to the same future: one reply resolves it.
        for msg_id in msg_ids:
            self._pending[msg_id] = future

        try:
            await self.ws_connect.send_str(self.json_serialize(data))

            if self.timeout is None:
                result = await future
            else:
                result = await asyncio.wait_for(future, timeout=self.timeout)
        finally:
            # Forget the ids on every exit path (success, timeout, send error)
            # so _pending does not grow and late replies are ignored. Only
            # remove entries still owned by this call's future.
            for msg_id in msg_ids:
                if self._pending.get(msg_id) is future:
                    del self._pending[msg_id]

        return result, None

    async def _handle_ws_messages(self) -> None:
        """Read and dispatch websocket messages until the socket is closed.

        If cancelled, every pending request is failed with ``errors.ServerError``
        before the cancellation propagates. Other exceptions are logged and
        reading continues.
        """
        while not self.ws_connect.closed:
            try:
                ws_msg = await self.ws_connect.receive()
                self._handle_ws_message(ws_msg)
            except asyncio.CancelledError as e:
                error = errors.ServerError(utils.get_exc_message(e)).with_traceback()
                self._notify_all_about_error(error)
                raise
            except Exception as e:
                logger.exception(e)

        # The socket is closed, so no reply can arrive anymore. Fail requests
        # still waiting instead of leaving them to hang (forever when
        # ``timeout`` is None).
        if self._pending:
            self._notify_all_about_error(errors.ServerError('The websocket connection is closed.'))

    def _handle_ws_message(self, ws_msg: aiohttp.WSMessage) -> None:
        """Decode one text frame and resolve the future(s) it answers.

        Non-text frames are ignored. A dict with an ``'id'`` resolves that id.
        A list resolves the batch identified by the ids of its dict items.
        """
        if ws_msg.type != aiohttp.WSMsgType.text:
            return

        json_response = json.loads(ws_msg.data)

        if isinstance(json_response, dict) and 'id' in json_response:
            self._notify_about_result(json_response['id'], json_response)
            return

        if isinstance(json_response, list):
            self._notify_about_results(
                [
                    item['id']
                    for item in json_response
                    if isinstance(item, dict) and 'id' in item
                ],
                json_response,
            )

    def _notify_all_about_error(self, error: errors.JsonRpcError) -> None:
        """Fail every pending request with ``error`` and clear the pending map."""
        for future in self._pending.values():
            # A batch stores one future under several ids, and a future may
            # already be done; setting it again would raise InvalidStateError.
            if not future.done():
                future.set_exception(error)

        self._pending = {}

    def _notify_about_result(self, msg_id: typing.Any, json_response: dict) -> None:
        """Resolve the future waiting for ``msg_id``, if it is still waiting."""
        future = self._pending.pop(msg_id, None)

        # Skip futures already cancelled (e.g. by a wait_for timeout).
        if future is not None and not future.done():
            future.set_result(json_response)

    def _notify_about_results(self, msg_ids: list, json_response: list) -> None:
        """Resolve a batch: pop all ``msg_ids`` and deliver the reply once."""
        is_processed = False

        for msg_id in msg_ids:
            future = self._pending.pop(msg_id, None)

            # All ids of one batch share a future, so deliver only once; skip
            # futures already cancelled by a timeout.
            if future is not None and not is_processed and not future.done():
                future.set_result(json_response)
                is_processed = True
146