Passed
Push — master ( 1e73b7...ae2d2b )
by Michael
04:53
created

BaseJsonRpcClient._collect_batch_result()   B

Complexity

Conditions 8

Size

Total Lines 36
Code Lines 25

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 25
dl 0
loc 36
rs 7.3333
c 0
b 0
f 0
cc 8
nop 2
1
import abc
2
import types
3
import typing
4
from functools import partial
5
6
from .. import errors, protocol, typedefs, utils
7
8
9
# Only the abstract client base class is exported from this module.
__all__ = ('BaseJsonRpcClient',)
12
13
14
class BaseJsonRpcClient(abc.ABC):
    """Transport-agnostic base class for asynchronous JSON-RPC clients.

    Subclasses implement :meth:`connect`, :meth:`disconnect` and
    :meth:`send_json`. This class does the rest:

    * it builds requests from method descriptions;
    * it parses replies into protocol objects;
    * it raises errors for single calls (:meth:`call`) and returns them as
      values for batches (:meth:`batch`).

    Any attribute that is not found normally is treated as a remote method
    name, so ``await client.add(1, 2)`` is shorthand for
    ``await client.call('add', 1, 2)``. The class is also an async context
    manager that connects on entry and disconnects on exit.
    """

    # JSON-RPC error code -> exception class, used by direct_call() when parsing
    # error responses. Defined at class level, so it is shared by every instance
    # that does not override it.
    error_map: typing.Dict[int, typing.Type[errors.JsonRpcError]] = {
        error.code: error
        for error in errors.DEFAULT_KNOWN_ERRORS
    }
    # Serializer made available to transports. It is wrapped in staticmethod so
    # that reading it through an instance (``self.json_serialize``) returns the
    # callable unchanged. Without the wrapper, ``self`` would be bound as its
    # first argument if ``utils.json_serialize`` is a plain function.
    json_serialize: typing.Callable = staticmethod(utils.json_serialize)

    async def __aenter__(self) -> 'BaseJsonRpcClient':
        await self.connect()
        return self

    async def __aexit__(self,
                        exc_type: typing.Optional[typing.Type[BaseException]],
                        exc_value: typing.Optional[BaseException],
                        traceback: typing.Optional[types.TracebackType]) -> None:
        # Returns None, so exceptions raised inside the ``async with`` body propagate.
        await self.disconnect()

    def __getattr__(self, method_name: str) -> typing.Callable:
        """Return a callable that invokes the remote method ``method_name``.

        This is only reached for names that normal attribute lookup does not
        find. Dunder names are rejected with :class:`AttributeError`. They are
        protocol probes made by Python and the standard library (for example,
        ``copy.deepcopy`` looks up ``__deepcopy__`` on the instance), never RPC
        method names. Answering them with a remote-call partial would make those
        tools send requests instead of behaving normally.
        """
        if method_name.startswith('__') and method_name.endswith('__'):
            raise AttributeError(f'{type(self).__name__!r} object has no attribute {method_name!r}')

        return partial(self.call, method_name)

    @abc.abstractmethod
    async def connect(self) -> None:
        """Open the underlying transport (awaited by ``__aenter__``)."""

    @abc.abstractmethod
    async def disconnect(self) -> None:
        """Close the underlying transport (awaited by ``__aexit__``)."""

    async def call(self, method_name: str, *args, **kwargs) -> typing.Any:
        """Invoke ``method_name`` remotely and return its result.

        Positional and keyword arguments are passed to the request as ``args``
        and ``kwargs``. The request gets a random id.

        :raises: the exception object carried by the response, if the server
            returned an error.
        """
        request = protocol.JsonRpcRequest(id=utils.get_random_id(), method_name=method_name, args=args, kwargs=kwargs)
        response = await self.direct_call(request)

        assert response is not None  # Because it isn't a notification

        if response.error is not None:
            raise response.error

        return response.result

    async def notify(self, method_name: str, *args, **kwargs) -> None:
        """Send ``method_name`` as a request built without an id; nothing is returned."""
        request = protocol.JsonRpcRequest(method_name=method_name, args=args, kwargs=kwargs)
        await self.direct_call(request)

    async def batch(self,
                    method_descriptions: typedefs.MethodDescriptionsType, *,
                    save_order: bool = True) -> typing.Any:
        """Send several requests in one JSON-RPC batch and return their outcomes.

        :param method_descriptions: either a ready ``protocol.JsonRpcBatchRequest``
            or an iterable of descriptions accepted by
            :meth:`_parse_method_description` (a method name, a
            ``JsonRpcRequest``, or a sequence of 1-3 items).
        :param save_order: when true, results are aligned with the requests
            (see :meth:`_collect_batch_result`). Otherwise they are listed in the
            order the server returned them.
        :returns: a list whose elements are results, or the error objects of
            error responses. Per-request errors are returned, not raised.
        :raises: whatever :meth:`direct_batch` raises for an empty or malformed
            batch exchange.
        """
        batch_request = self._build_batch_request(method_descriptions)
        batch_response = await self.direct_batch(batch_request)

        assert batch_response is not None  # Because it isn't a notification

        if save_order:
            return self._collect_batch_result(batch_request, batch_response)

        return [
            response.result if response.error is None else response.error
            for response in batch_response.responses
        ]

    async def batch_notify(self, method_descriptions: typedefs.MethodDescriptionsType) -> None:
        """Send several notifications (requests built without ids) in one batch.

        Descriptions are accepted in the same forms as for :meth:`batch`. A
        ready ``JsonRpcBatchRequest`` or ``JsonRpcRequest`` is sent unchanged.
        """
        batch_request = self._build_batch_request(method_descriptions, is_notification=True)
        await self.direct_batch(batch_request)

    async def direct_call(self,
                          request: protocol.JsonRpcRequest,
                          **kwargs) -> typing.Optional[protocol.JsonRpcResponse]:
        """Send a single ``request`` and parse the reply.

        Extra ``kwargs`` are forwarded to :meth:`send_json`. Error responses are
        not raised here; callers inspect ``response.error``.

        :returns: ``None`` for notifications (``send_json`` is told not to
            expect a reply); otherwise the parsed ``protocol.JsonRpcResponse``.
        """
        json_response, context = await self.send_json(
            request.to_dict(),
            without_response=request.is_notification,
            **kwargs,
        )

        if request.is_notification:
            return None

        return protocol.JsonRpcResponse.from_dict(
            json_response,
            error_map=self.error_map,
            context=context,
        )

    async def direct_batch(self,
                           batch_request: protocol.JsonRpcBatchRequest,
                           **kwargs) -> typing.Optional[protocol.JsonRpcBatchResponse]:
        """Send ``batch_request`` and parse the server's reply.

        Extra ``kwargs`` are forwarded to :meth:`send_json`.

        :returns: ``None`` if the batch is a notification batch; otherwise the
            parsed ``protocol.JsonRpcBatchResponse``.
        :raises errors.InvalidRequest: if the batch contains no requests.
        :raises errors.ParseError: if the reply is empty, or if it is a single
            non-error object instead of a list.
        :raises: the exception object carried by a single error response.
            JSON-RPC 2.0 servers send such a response instead of a list when the
            batch as a whole cannot be processed.
        """
        if not batch_request.requests:
            raise errors.InvalidRequest('You can not send an empty batch request.')

        is_notification = batch_request.is_notification

        json_response, context = await self.send_json(
            batch_request.to_list(),
            without_response=is_notification,
            **kwargs,
        )

        if is_notification:
            return None

        if not json_response:
            raise errors.ParseError('Server returned an empty batch response.')

        if isinstance(json_response, dict):
            # JSON-RPC 2.0, section 6: if the batch itself is rejected (e.g.
            # invalid JSON), the server answers with ONE Response object, not an
            # array. Surface that error instead of iterating over the dict's keys.
            response = protocol.JsonRpcResponse.from_dict(
                json_response,
                error_map=self.error_map,
                context=context,
            )

            if response.error is not None:
                raise response.error

            raise errors.ParseError('Server returned a single response to a batch request.')

        # NOTE(review): unlike direct_call(), error_map and context are not passed
        # here, so a customised error_map may not apply to batch responses.
        # Confirm whether JsonRpcBatchResponse.from_list accepts them.
        return protocol.JsonRpcBatchResponse.from_list(json_response)

    @abc.abstractmethod
    async def send_json(self,
                        data: typing.Any, *,
                        without_response: bool = False,
                        **kwargs) -> typing.Tuple[typing.Any, typing.Optional[dict]]:
        """Transmit ``data`` over the transport (implemented by subclasses).

        :param data: the output of ``JsonRpcRequest.to_dict()`` or
            ``JsonRpcBatchRequest.to_list()``.
        :param without_response: true for notifications. Callers ignore the
            returned response in that case.
        :returns: ``(json_response, context)``. ``json_response`` is the decoded
            reply. ``context`` is passed on to ``JsonRpcResponse.from_dict`` by
            :meth:`direct_call`.
        """

    def _build_batch_request(self,
                             method_descriptions: typedefs.MethodDescriptionsType, *,
                             is_notification: bool = False) -> protocol.JsonRpcBatchRequest:
        """Return ``method_descriptions`` as a batch request, parsing each description if needed."""
        if isinstance(method_descriptions, protocol.JsonRpcBatchRequest):
            return method_descriptions

        return protocol.JsonRpcBatchRequest(requests=[
            self._parse_method_description(method_description, is_notification=is_notification)
            for method_description in method_descriptions
        ])

    @staticmethod
    def _collect_batch_result(batch_request: protocol.JsonRpcBatchRequest,
                              batch_response: protocol.JsonRpcBatchResponse) -> list:
        """Align batch responses with the requests that produced them.

        Returns one entry per request in ``batch_request.requests``:

        * the value (result, or error object for error responses) of the
          response whose ``id`` matches the request;
        * a ``protocol.DuplicatedResults`` holding every value, when several
          responses share that ``id``;
        * for notifications, and for requests with no matching response, the
          collected id-less responses (``protocol.UnlinkedResults``) if that
          collection is truthy, else ``None``.

        Responses whose id matches no request are not included in the result.
        """
        unlinked_results = protocol.UnlinkedResults()
        responses_map: typing.Dict[typing.Any, typing.Any] = {}

        for response in batch_response.responses:
            value = response.result if response.error is None else response.error

            if response.id is None:
                unlinked_results.add(value)
                continue

            if response.id not in responses_map:
                responses_map[response.id] = value
            elif isinstance(responses_map[response.id], protocol.DuplicatedResults):
                responses_map[response.id].add(value)
            else:
                responses_map[response.id] = protocol.DuplicatedResults(data=[
                    responses_map[response.id],
                    value,
                ])

        # unlinked_results is no longer modified, so the fallback is computed once.
        fallback = unlinked_results or None

        return [
            fallback if request.is_notification else responses_map.get(request.id, fallback)
            for request in batch_request.requests
        ]

    @staticmethod
    def _parse_method_description(method_description: typedefs.MethodDescriptionType, *,
                                  is_notification: bool = False) -> protocol.JsonRpcRequest:
        """Convert one method description into a ``protocol.JsonRpcRequest``.

        Accepted forms:

        * a ``JsonRpcRequest``: returned unchanged (``is_notification`` is ignored);
        * ``'name'`` or ``('name',)``: a request without parameters;
        * ``('name', params)``: ``params`` becomes the request's ``params``;
        * ``('name', args, kwargs)``: these become the request's ``args`` and ``kwargs``.

        :param is_notification: build the request with ``id=None`` instead of a
            random id.
        :raises errors.InvalidParams: for a sequence that is empty or longer
            than 3 items.
        """
        if isinstance(method_description, protocol.JsonRpcRequest):
            return method_description

        request_id = None if is_notification else utils.get_random_id()

        if isinstance(method_description, str):
            return protocol.JsonRpcRequest(id=request_id, method_name=method_description)

        size = len(method_description)

        if size == 1:
            return protocol.JsonRpcRequest(id=request_id, method_name=method_description[0])

        if size == 2:
            return protocol.JsonRpcRequest(
                id=request_id,
                method_name=method_description[0],
                params=method_description[1],
            )

        if size == 3:
            return protocol.JsonRpcRequest(
                id=request_id,
                method_name=method_description[0],
                args=method_description[1],
                kwargs=method_description[2],
            )

        raise errors.InvalidParams('Use string or list (length less than or equal to 3).')
213