Passed
Pull Request #5 (master) — created by Michael at 05:44

BaseJsonRpcClient._collect_batch_result()   C

Complexity

Conditions 9

Size

Total Lines 40
Code Lines 28

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric | Value
eloc   | 28
dl     | 0
loc    | 40
rs     | 6.6666
c      | 0
b      | 0
f      | 0
cc     | 9
nop    | 2
1
import abc
2
import types
3
import typing
4
from functools import partial
5
6
from .. import constants, errors, protocol, utils
7
8
9
# Public names exported by this module.
__all__ = ('BaseJsonRpcClient',)

# Forms accepted for a single entry of a batch: a prepared
# ``protocol.CalledJsonRpcMethod`` is used directly, while a str, list or
# tuple is converted via ``protocol.CalledJsonRpcMethod.from_params``.
MethodType = typing.Union[str, list, tuple, protocol.CalledJsonRpcMethod]
14
15
16
class BaseJsonRpcClient(abc.ABC):
    """Base class for asynchronous JSON-RPC clients.

    Subclasses supply the transport by implementing :meth:`connect`,
    :meth:`disconnect` and :meth:`send_json`. This class builds requests,
    parses responses and matches batch responses back to their requests.

    Any attribute not defined on the client becomes an RPC call, so
    ``await client.add(1, 2)`` is equivalent to
    ``await client.call('add', 1, 2)``. The client is also an async context
    manager: it connects on entry and disconnects on exit.
    """

    # Error code -> error, used when parsing responses; seeded from
    # errors.DEFAULT_KNOWN_ERRORS. This dict is a class attribute shared by all
    # instances (and by subclasses that do not reassign it), so replace it
    # rather than mutating it in place.
    error_map: typing.Dict[int, errors.JsonRpcError] = {
        error.code: error
        for error in errors.DEFAULT_KNOWN_ERRORS
    }
    # NOTE(review): stored as a plain class attribute. If utils.json_serialize
    # is an ordinary function, ``self.json_serialize`` binds it as a method;
    # confirm how transports call it.
    json_serialize: typing.Callable = utils.json_serialize

    async def __aenter__(self) -> 'BaseJsonRpcClient':
        """Connect and return the client itself."""
        await self.connect()
        return self

    async def __aexit__(self,
                        exc_type: typing.Optional[typing.Type[BaseException]],
                        exc_value: typing.Optional[BaseException],
                        traceback: typing.Optional[types.TracebackType]) -> None:
        """Disconnect; exceptions from the ``async with`` body are not suppressed."""
        await self.disconnect()

    def __getattr__(self, method: str) -> typing.Callable:
        """Return a callable performing ``self.call(method, *args, **kwargs)``.

        Python only invokes this for attributes not found normally. Dunder
        names raise AttributeError instead of becoming RPC calls: protocol
        probes such as ``copy.deepcopy`` looking up ``__deepcopy__`` on the
        instance would otherwise get a partial and call it, producing an
        un-awaited coroutine instead of a copy.
        """
        if method.startswith('__') and method.endswith('__'):
            raise AttributeError(f'{type(self).__name__!r} object has no attribute {method!r}')

        return partial(self.call, method)

    @abc.abstractmethod
    async def connect(self) -> None:
        """Open the underlying transport (implemented by subclasses)."""

    @abc.abstractmethod
    async def disconnect(self) -> None:
        """Close the underlying transport (implemented by subclasses)."""

    async def call(self, method: str, *args, **kwargs) -> typing.Any:
        """Call ``method`` and return its result.

        ``args`` and ``kwargs`` are passed to the request, which gets a random
        message id.

        Raises:
            The error carried by the response, when it has one.
        """
        request = protocol.JsonRpcRequest(
            msg_id=utils.get_random_msg_id(),
            method=method,
            args=args,
            kwargs=kwargs,
        )
        response = await self.direct_call(request)

        if response.error not in constants.EMPTY_VALUES:
            raise response.error

        return response.result

    async def notify(self, method: str, *args, **kwargs) -> None:
        """Send ``method`` without a message id and without reading a response."""
        request = protocol.JsonRpcRequest(method=method, args=args, kwargs=kwargs)
        await self.direct_call(request, without_response=True)

    async def batch(self,
                    methods: typing.Iterable[MethodType], *,
                    save_order: bool = True) -> list:
        """Send several calls in one batch request and return their outcomes.

        Each item of ``methods`` is converted with :meth:`_parse_method`.
        Errors are returned in place of results; they are not raised.

        Args:
            methods: the calls to make (see ``MethodType``).
            save_order: if True, the returned list is aligned with the order
                of ``methods`` (see :meth:`_collect_batch_result`); if False,
                outcomes are listed in the order the server returned them.

        Raises:
            errors.InvalidRequest: if ``methods`` is empty.
            errors.ParseError: if the server returns an empty batch response.
        """
        requests = [self._parse_method(method) for method in methods]
        batch_request = protocol.JsonRpcBatchRequest(requests=requests)
        batch_response = await self.direct_batch(batch_request)

        if save_order:
            return self._collect_batch_result(batch_request, batch_response)

        return [
            response.result if response.error in constants.EMPTY_VALUES else response.error
            for response in batch_response.responses
        ]

    async def batch_notify(self, methods: typing.Iterable[MethodType]) -> None:
        """Send several notifications in one batch; no response is read.

        Raises:
            errors.InvalidRequest: if ``methods`` is empty.
        """
        requests = [self._parse_method(method, is_notification=True) for method in methods]
        batch_request = protocol.JsonRpcBatchRequest(requests=requests)
        await self.direct_batch(batch_request, without_response=True)

    async def direct_call(self,
                          request: protocol.JsonRpcRequest, *,
                          without_response: bool = False) -> typing.Optional[protocol.JsonRpcResponse]:
        """Send a prepared request and parse the reply.

        Returns ``None`` when ``without_response`` is True. Otherwise returns the
        reply parsed with this client's ``error_map`` and the transport context
        returned by :meth:`send_json`.
        """
        json_response, context = await self.send_json(request.to_dict(), without_response=without_response)

        if without_response:
            return None

        return protocol.JsonRpcResponse.from_dict(
            json_response,
            error_map=self.error_map,
            context=context,
        )

    async def direct_batch(self,
                           batch_request: protocol.JsonRpcBatchRequest, *,
                           without_response: bool = False) -> typing.Optional[protocol.JsonRpcBatchResponse]:
        """Send a prepared batch request and parse the reply.

        Returns ``None`` when ``without_response`` is True.

        Raises:
            errors.InvalidRequest: if the batch contains no requests.
            errors.ParseError: if a response was expected but came back empty.
        """
        if not batch_request.requests:
            raise errors.InvalidRequest('You can not send an empty batch request.')

        json_response, context = await self.send_json(batch_request.to_list(), without_response=without_response)

        if without_response:
            return None

        if not json_response:
            raise errors.ParseError('Server returned an empty batch response.')

        # NOTE(review): unlike direct_call, neither self.error_map nor context is
        # passed here, so batch errors may not be mapped through a customised
        # error_map. Confirm JsonRpcBatchResponse.from_list's signature before
        # forwarding them.
        return protocol.JsonRpcBatchResponse.from_list(json_response)

    @abc.abstractmethod
    async def send_json(self,
                        data: typing.Any, *,
                        without_response: bool = False) -> typing.Tuple[typing.Any, typing.Optional[dict]]:
        """Transmit ``data`` and return ``(decoded_response, context)``.

        ``data`` is the output of ``JsonRpcRequest.to_dict()`` or
        ``JsonRpcBatchRequest.to_list()``. When ``without_response`` is True,
        the callers in this class ignore the returned response. The context
        is forwarded to ``JsonRpcResponse.from_dict`` by :meth:`direct_call`.
        """

    @staticmethod
    def _collect_batch_result(batch_request: protocol.JsonRpcBatchRequest,
                              batch_response: protocol.JsonRpcBatchResponse) -> list:
        """Match batch responses back to the requests that produced them.

        Returns one entry per request, in ``batch_request.requests`` order:

        * the response's result, or its error when it carries one;
        * a ``protocol.DuplicatedResults`` holding every value when several
          responses share the request's msg_id;
        * for notifications and for requests with no matching response, the
          ``protocol.UnlinkedResults`` gathered from responses without a
          msg_id, or ``None`` when there were none.
        """
        unlinked_results = protocol.UnlinkedResults()
        responses_map = {}

        for response in batch_response.responses:
            if response.error in constants.EMPTY_VALUES:
                value = response.result
            else:
                value = response.error

            msg_id = response.msg_id

            # Responses without an id cannot be linked to any particular request.
            if msg_id in constants.EMPTY_VALUES:
                unlinked_results.add(value)
                continue

            if msg_id not in responses_map:
                responses_map[msg_id] = value
                continue

            # The same id was answered more than once: keep every value instead
            # of letting the latest response silently replace the earlier ones.
            existing = responses_map[msg_id]

            if isinstance(existing, protocol.DuplicatedResults):
                existing.add(value)
            else:
                responses_map[msg_id] = protocol.DuplicatedResults(data=[existing, value])

        if not unlinked_results:
            unlinked_results = None

        return [
            unlinked_results if request.is_notification
            else responses_map.get(request.msg_id, unlinked_results)
            for request in batch_request.requests
        ]

    @staticmethod
    def _parse_method(method: MethodType, *, is_notification: bool = False) -> protocol.JsonRpcRequest:
        """Convert one batch item into a ``protocol.JsonRpcRequest``.

        A ``protocol.CalledJsonRpcMethod`` is used as-is: ``is_notification`` is
        not applied to it, and the object itself is mutated if it needs a
        msg_id. Any other value goes through
        ``CalledJsonRpcMethod.from_params`` and is marked with
        ``is_notification``. Non-notifications without a msg_id are given a
        random one so their responses can be matched back.
        """
        if isinstance(method, protocol.CalledJsonRpcMethod):
            called_method = method
        else:
            called_method = protocol.CalledJsonRpcMethod.from_params(method)
            called_method.is_notification = is_notification

        if called_method.msg_id in constants.EMPTY_VALUES and not called_method.is_notification:
            called_method.msg_id = utils.get_random_msg_id()

        return called_method.to_request()
168