Passed
Push — master ( 0d356b...652bd4 )
by Michael
11:41 queued 05:19
created

aiohttp_rpc.client.base.BaseJsonRpcClient.notify()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 3
dl 0
loc 3
rs 10
c 0
b 0
f 0
cc 1
nop 4
1
import abc
2
import types
3
import typing
4
from functools import partial
5
6
from .. import constants, errors, protocol, utils
7
8
9
__all__ = (
    'BaseJsonRpcClient',
)

# One call description accepted by ``BaseJsonRpcClient.batch``/``batch_notify``:
# a method name, a list/tuple of 1-3 items (see ``_parse_method_description``),
# or an already-built request.
MethodDescription = typing.Union[str, list, tuple, protocol.JsonRpcRequest]
# The batch argument: either an iterable of descriptions, or a ready-made batch request
# (which is used as-is). A batch request is NOT valid as an element of the iterable.
MethodDescriptionList = typing.Union[typing.Iterable[MethodDescription], protocol.JsonRpcBatchRequest]
15
16
17
class BaseJsonRpcClient(abc.ABC):
    """Transport-agnostic base class for JSON-RPC clients.

    Subclasses implement ``connect``, ``disconnect`` and ``send_json``. This class
    builds requests, parses responses and raises/collects errors. Attributes that
    are not found normally resolve to remote calls, so ``await client.add(1, 2)``
    is shorthand for ``await client.call('add', 1, 2)``.
    """

    # Maps JSON-RPC error codes to the error objects used when parsing responses.
    # NOTE(review): built from errors.DEFAULT_KNOWN_ERRORS; if those entries are
    # classes the annotation should be Dict[int, Type[JsonRpcError]] — confirm.
    error_map: typing.Dict[int, errors.JsonRpcError] = {
        error.code: error
        for error in errors.DEFAULT_KNOWN_ERRORS
    }
    # Serializer exposed for subclasses/transports; this class does not call it itself.
    json_serialize: typing.Callable = utils.json_serialize

    async def __aenter__(self) -> 'BaseJsonRpcClient':
        """Connect and return the client for use in ``async with``."""
        await self.connect()
        return self

    async def __aexit__(self,
                        exc_type: typing.Optional[typing.Type[BaseException]],
                        exc_value: typing.Optional[BaseException],
                        traceback: typing.Optional[types.TracebackType]) -> None:
        """Disconnect on leaving ``async with``; exceptions are not suppressed."""
        await self.disconnect()

    def __getattr__(self, method_name: str) -> typing.Callable:
        """Return a callable that performs a remote call of ``method_name``.

        Only invoked for attributes missing from normal lookup. Dunder names are
        rejected with ``AttributeError``. Protocol probes such as ``copy.deepcopy``
        (which looks up ``__deepcopy__`` on the instance) or ``hasattr(obj, '__wrapped__')``
        would otherwise receive an RPC proxy and misbehave.
        """
        if method_name.startswith('__') and method_name.endswith('__'):
            raise AttributeError(f'{type(self).__name__!r} object has no attribute {method_name!r}')

        return partial(self.call, method_name)

    @abc.abstractmethod
    async def connect(self) -> None:
        """Open the underlying transport (implemented by subclasses)."""
        pass

    @abc.abstractmethod
    async def disconnect(self) -> None:
        """Close the underlying transport (implemented by subclasses)."""
        pass

    async def call(self, method_name: str, *args, **kwargs) -> typing.Any:
        """Call ``method_name`` remotely and return its result.

        Raises:
            The error object carried by the response, if it has one.
        """
        request = protocol.JsonRpcRequest(id=utils.get_random_id(), method_name=method_name, args=args, kwargs=kwargs)
        response = await self.direct_call(request)

        if response.error not in constants.EMPTY_VALUES:
            raise response.error

        return response.result

    async def notify(self, method_name: str, *args, **kwargs) -> None:
        """Send ``method_name`` as a notification (no id, so no response is awaited)."""
        request = protocol.JsonRpcRequest(method_name=method_name, args=args, kwargs=kwargs)
        await self.direct_call(request)

    async def batch(self,
                    method_descriptions: MethodDescriptionList, *,
                    save_order: bool = True) -> typing.Any:
        """Send several calls in one batch and return their results.

        Args:
            method_descriptions: a ready-made ``JsonRpcBatchRequest`` (sent as-is), or an
                iterable of descriptions accepted by ``_parse_method_description``.
            save_order: when true, results are aligned with the requests (see
                ``_collect_batch_result``); otherwise they follow the response order.

        Returns:
            A list where each item is a response's result, or its error object
            if the response carries one (errors are returned, not raised).
        """
        if isinstance(method_descriptions, protocol.JsonRpcBatchRequest):
            batch_request = method_descriptions
        else:
            batch_request = protocol.JsonRpcBatchRequest(requests=[
                self._parse_method_description(method_description)
                for method_description in method_descriptions
            ])

        batch_response = await self.direct_batch(batch_request)

        if batch_response is None:
            # direct_batch returns None when the batch consists only of notifications
            # (e.g. a caller-supplied batch request), so there are no responses.
            # _collect_batch_result yields None per notification when nothing is unlinked.
            return [None for _ in batch_request.requests] if save_order else []

        if save_order:
            return self._collect_batch_result(batch_request, batch_response)

        return [self._response_value(response) for response in batch_response.responses]

    async def batch_notify(self, method_descriptions: MethodDescriptionList) -> None:
        """Send several notifications in one batch; nothing is returned."""
        if isinstance(method_descriptions, protocol.JsonRpcBatchRequest):
            batch_request = method_descriptions
        else:
            batch_request = protocol.JsonRpcBatchRequest(requests=[
                self._parse_method_description(method_description, is_notification=True)
                for method_description in method_descriptions
            ])

        await self.direct_batch(batch_request)

    async def direct_call(self, request: protocol.JsonRpcRequest) -> typing.Optional[protocol.JsonRpcResponse]:
        """Send one prepared request.

        Returns the parsed response, or None for notifications (no response is awaited).
        """
        json_response, context = await self.send_json(request.to_dict(), without_response=request.is_notification)

        if request.is_notification:
            return None

        return protocol.JsonRpcResponse.from_dict(
            json_response,
            error_map=self.error_map,
            context=context,
        )

    async def direct_batch(self,
                           batch_request: protocol.JsonRpcBatchRequest,
                           ) -> typing.Optional[protocol.JsonRpcBatchResponse]:
        """Send one prepared batch request.

        Returns:
            The parsed batch response, or None when the batch is a notification batch.

        Raises:
            errors.InvalidRequest: if the batch contains no requests.
            errors.ParseError: if the server replies with an empty body/list.
        """
        if not batch_request.requests:
            raise errors.InvalidRequest('You can not send an empty batch request.')

        is_notification = batch_request.is_notification
        json_response, context = await self.send_json(batch_request.to_list(), without_response=is_notification)

        if is_notification:
            return None

        if not json_response:
            raise errors.ParseError('Server returned an empty batch response.')

        # NOTE(review): unlike direct_call, neither error_map nor context is passed here,
        # so custom error_map entries may not apply to batch responses — check
        # JsonRpcBatchResponse.from_list's signature before forwarding them.
        return protocol.JsonRpcBatchResponse.from_list(json_response)

    @abc.abstractmethod
    async def send_json(self,
                        data: typing.Any, *,
                        without_response: bool = False) -> typing.Tuple[typing.Any, typing.Optional[dict]]:
        """Transmit ``data`` and return ``(decoded_json_response, context)``.

        Implemented by subclasses. With ``without_response=True`` the returned
        response is ignored by the callers in this class.
        """
        pass

    @staticmethod
    def _response_value(response: protocol.JsonRpcResponse) -> typing.Any:
        """Return the response's result, or its error object when it carries one."""
        if response.error in constants.EMPTY_VALUES:
            return response.result

        return response.error

    @staticmethod
    def _collect_batch_result(batch_request: protocol.JsonRpcBatchRequest,
                              batch_response: protocol.JsonRpcBatchResponse) -> list:
        """Align batch response values with the requests that produced them.

        Responses are matched by id. Several responses with the same id are merged into
        ``DuplicatedResults``. Responses without an id go into ``UnlinkedResults``, which
        is used for notifications and for requests with no matching response (None if
        there were no unlinked responses).
        """
        unlinked_results = protocol.UnlinkedResults()
        responses_map = {}

        for response in batch_response.responses:
            value = BaseJsonRpcClient._response_value(response)

            if response.id in constants.EMPTY_VALUES:
                unlinked_results.add(value)
                continue

            if response.id not in responses_map:
                responses_map[response.id] = value
            elif isinstance(responses_map[response.id], protocol.DuplicatedResults):
                responses_map[response.id].add(value)
            else:
                responses_map[response.id] = protocol.DuplicatedResults(data=[
                    responses_map[response.id],
                    value,
                ])

        if not unlinked_results:
            unlinked_results = None

        return [
            unlinked_results if request.is_notification else responses_map.get(request.id, unlinked_results)
            for request in batch_request.requests
        ]

    @staticmethod
    def _parse_method_description(method_description: MethodDescription, *,
                                  is_notification: bool = False) -> protocol.JsonRpcRequest:
        """Convert one batch item into a ``JsonRpcRequest``.

        Accepted forms:
            * a ``JsonRpcRequest`` — returned unchanged;
            * ``'name'`` or ``('name',)`` — call without parameters;
            * ``('name', params)`` — passed through as ``params``;
            * ``('name', args, kwargs)`` — positional and keyword arguments.

        Raises:
            errors.InvalidParams: for any other shape (including an empty sequence).
        """
        if isinstance(method_description, protocol.JsonRpcRequest):
            return method_description

        # Notifications carry no id, so the server sends no response for them.
        request_id = constants.NOTHING if is_notification else utils.get_random_id()

        if isinstance(method_description, str):
            return protocol.JsonRpcRequest(
                id=request_id,
                method_name=method_description,
            )

        if len(method_description) == 1:
            return protocol.JsonRpcRequest(
                id=request_id,
                method_name=method_description[0],
            )

        if len(method_description) == 2:
            return protocol.JsonRpcRequest(
                id=request_id,
                method_name=method_description[0],
                params=method_description[1],
            )

        if len(method_description) == 3:
            return protocol.JsonRpcRequest(
                id=request_id,
                method_name=method_description[0],
                args=method_description[1],
                kwargs=method_description[2],
            )

        raise errors.InvalidParams('Use string or list (length less than or equal to 3).')
206