Test Failed
Pull Request — master (#5)
by Michael
15:16
created

BaseJsonRpcClient._collect_batch_result()   B

Complexity

Conditions 5

Size

Total Lines 26
Code Lines 19

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 19
dl 0
loc 26
rs 8.9833
c 0
b 0
f 0
cc 5
nop 2
1
import abc
2
import types
3
import typing
4
import uuid
5
from dataclasses import dataclass
6
from functools import partial
7
8
from .. import constants, errors, protocol, utils
9
10
11
# Names exported by ``from <module> import *``.
__all__ = (
    'BaseJsonRpcClient',
    'UnlinkedResults',
)
15
16
17
class BaseJsonRpcClient(abc.ABC):
    """Abstract base for asynchronous JSON-RPC clients.

    Subclasses provide the transport by implementing :meth:`connect`,
    :meth:`disconnect` and :meth:`send_json`. This class builds the
    ``protocol`` request objects, passes their ``to_dict()`` / ``to_list()``
    form to :meth:`send_json`, and converts the replies into results.

    Instances are async context managers: entering awaits :meth:`connect`
    and leaving awaits :meth:`disconnect`.

    Attributes that are not found through normal lookup resolve to a callable
    bound to :meth:`call`, so ``await client.ping(1)`` behaves like
    ``await client.call('ping', 1)``.
    """

    # Maps each entry's ``code`` to the entry itself, for every item of
    # ``errors.DEFAULT_KNOWN_ERRORS``. Handed to ``JsonRpcResponse.from_dict``
    # in ``direct_call``.
    error_map: typing.Dict[int, errors.JsonRpcError] = {
        error.code: error
        for error in errors.DEFAULT_KNOWN_ERRORS
    }
    # Not referenced inside this class.
    # NOTE(review): presumably used by transports to serialise payloads; if
    # ``utils.json_serialize`` is a plain function it will be bound as a
    # method when accessed through an instance - confirm against subclasses.
    json_serialize: typing.Callable = utils.json_serialize

    async def __aenter__(self) -> 'BaseJsonRpcClient':
        """Await :meth:`connect` and return this client."""
        await self.connect()
        return self

    async def __aexit__(self,
                        exc_type: typing.Optional[typing.Type[BaseException]],
                        exc_value: typing.Optional[BaseException],
                        traceback: typing.Optional[types.TracebackType]) -> None:
        """Await :meth:`disconnect`.

        The exception details are ignored and ``None`` is returned, so an
        exception raised inside the ``async with`` body keeps propagating.
        """
        await self.disconnect()

    @abc.abstractmethod
    async def connect(self) -> None:
        """Open the transport. Awaited by ``__aenter__``."""
        pass

    @abc.abstractmethod
    async def disconnect(self) -> None:
        """Close the transport. Awaited by ``__aexit__``."""
        pass

    async def call(self, method: str, *args, **kwargs) -> typing.Any:
        """Call a remote method and return its result.

        A request with a fresh ``uuid4`` string as ``msg_id`` is sent through
        :meth:`direct_call`.

        Args:
            method: Name of the remote method.
            *args: Positional arguments forwarded as the request's ``args``.
            **kwargs: Keyword arguments forwarded as the request's ``kwargs``.

        Returns:
            ``response.result`` of the parsed response.

        Raises:
            The response's ``error`` object, whenever it is not one of
            ``constants.EMPTY_VALUES``.
        """
        request = protocol.JsonRpcRequest(msg_id=str(uuid.uuid4()), method=method, args=args, kwargs=kwargs)
        response = await self.direct_call(request)

        if response.error not in constants.EMPTY_VALUES:
            raise response.error

        return response.result

    async def notify(self, method: str, *args, **kwargs) -> None:
        """Send a request without a ``msg_id`` and do not use any reply.

        The request is passed to :meth:`send_json` with
        ``without_response=True``; whatever that returns is discarded.
        """
        request = protocol.JsonRpcRequest(method=method, args=args, kwargs=kwargs)
        await self.send_json(request.to_dict(), without_response=True)

    async def batch(self, methods: typing.Iterable[typing.Union[str, list, tuple]]) -> typing.Any:
        """Send several calls as one batch and return their values in order.

        Args:
            methods: One entry per call, each in a form accepted by
                :meth:`_parse_batch_method` (a method name, or a sequence of
                one to three items).

        Returns:
            The list built by :meth:`_collect_batch_result`: one value per
            request, in request order.

        Raises:
            errors.InvalidParams: If an entry has an unsupported length.
        """
        batch_request = protocol.JsonRpcBatchRequest(requests=[self._parse_batch_method(method) for method in methods])
        batch_response = await self.direct_batch(batch_request)
        return self._collect_batch_result(batch_request, batch_response)

    async def batch_notify(self, methods: typing.Iterable[typing.Union[str, list, tuple]]) -> None:
        """Send several notifications as one batch; no reply is used.

        Entries are parsed like in :meth:`batch`, but with
        ``is_notification=True``, and the batch is passed to
        :meth:`send_json` with ``without_response=True``.

        Raises:
            errors.InvalidParams: If an entry has an unsupported length.
        """
        batch_request = protocol.JsonRpcBatchRequest(
            requests=[self._parse_batch_method(method, is_notification=True) for method in methods],
        )
        await self.send_json(batch_request.to_list(), without_response=True)

    async def direct_call(self, request: protocol.JsonRpcRequest) -> protocol.JsonRpcResponse:
        """Send a prepared request and return the parsed response.

        Unlike :meth:`call`, an error carried by the response is not raised;
        the response object is returned as parsed.
        """
        json_response, context = await self.send_json(request.to_dict())
        return protocol.JsonRpcResponse.from_dict(
            json_response,
            error_map=self.error_map,
            context=context,
        )

    async def direct_batch(self, batch_request: protocol.JsonRpcBatchRequest) -> protocol.JsonRpcBatchResponse:
        """Send a prepared batch request and return the parsed batch response.

        The context returned by :meth:`send_json` is not used here.
        """
        json_response, _ = await self.send_json(batch_request.to_list())
        return protocol.JsonRpcBatchResponse.from_list(json_response)

    @abc.abstractmethod
    async def send_json(self,
                        data: typing.Any, *,
                        without_response: bool = False) -> typing.Tuple[typing.Any, typing.Optional[dict]]:
        """Transmit ``data`` over the transport.

        Args:
            data: The payload to send; callers in this class pass the result
                of ``to_dict()`` / ``to_list()`` of a request object.
            without_response: Set to ``True`` by :meth:`notify` and
                :meth:`batch_notify`, which ignore the return value.

        Returns:
            A ``(json_response, context)`` pair, as unpacked by
            :meth:`direct_call` and :meth:`direct_batch`.
        """
        pass

    def __getattr__(self, method) -> typing.Callable:
        """Expose remote methods as attributes: ``client.foo`` -> ``call('foo', ...)``.

        Raises:
            AttributeError: For dunder names (``__x__``). Those are protocol
                probes (for example ``copy.deepcopy`` looking up
                ``__deepcopy__``) and must not be answered with an RPC
                callable. Such remote methods remain reachable via
                :meth:`call`.
        """
        if method.startswith('__') and method.endswith('__'):
            raise AttributeError(method)

        return partial(self.call, method)

    @staticmethod
    def _collect_batch_result(batch_request: protocol.JsonRpcBatchRequest,
                              batch_response: protocol.JsonRpcBatchResponse) -> list:
        """Match batch responses to their requests.

        Each response contributes its ``error`` when that is not one of
        ``constants.EMPTY_VALUES``, otherwise its ``result``. Responses whose
        ``msg_id`` is empty cannot be linked to a request and are gathered
        separately; that collection is compiled to ``None`` (nothing
        gathered), the single gathered value, or an ``UnlinkedResults``.

        Returns:
            One value per request, in request order: the value of the
            response with the same ``msg_id``, or the compiled unlinked
            results for notifications and for requests without a matching
            response.
        """
        unlinked = UnlinkedResults(data=[])
        values_by_id = {}

        for response in batch_response.responses:
            # Use the same emptiness test as ``call`` so that an "empty"
            # error that happens to be truthy is not mistaken for a failure.
            if response.error not in constants.EMPTY_VALUES:
                value = response.error
            else:
                value = response.result

            if response.msg_id in constants.EMPTY_VALUES:
                unlinked.data.append(value)
            else:
                values_by_id[response.msg_id] = value

        fallback = unlinked.compile()

        return [
            fallback if request.is_notification else values_by_id.get(request.msg_id, fallback)
            for request in batch_request.requests
        ]

    @staticmethod
    def _parse_batch_method(batch_method: typing.Union[str, list, tuple], *,
                            is_notification: bool = False) -> protocol.JsonRpcRequest:
        """Build a request from one entry of a batch specification.

        Accepted forms:

        * ``'name'`` or ``['name']`` - method without parameters;
        * ``['name', params]`` - second item passed as ``params``;
        * ``['name', args, kwargs]`` - items passed as ``args`` and ``kwargs``.

        Args:
            batch_method: The entry to parse.
            is_notification: When true the request gets ``constants.NOTHING``
                as ``msg_id`` instead of a fresh ``uuid4`` string.

        Raises:
            errors.InvalidParams: If a non-string entry has a length other
                than 1, 2 or 3.
        """
        msg_id = constants.NOTHING if is_notification else str(uuid.uuid4())

        if isinstance(batch_method, str):
            return protocol.JsonRpcRequest(
                msg_id=msg_id,
                method=batch_method,
            )

        if len(batch_method) == 1:
            return protocol.JsonRpcRequest(
                msg_id=msg_id,
                method=batch_method[0],
            )

        if len(batch_method) == 2:
            return protocol.JsonRpcRequest(
                msg_id=msg_id,
                method=batch_method[0],
                params=batch_method[1],
            )

        if len(batch_method) == 3:
            return protocol.JsonRpcRequest(
                msg_id=msg_id,
                method=batch_method[0],
                args=batch_method[1],
                kwargs=batch_method[2],
            )

        raise errors.InvalidParams('Use string or list (length less than or equal to 3).')
148
149
150
@dataclass
class UnlinkedResults:
    """Holder for batch-response values that could not be tied to a request.

    ``BaseJsonRpcClient._collect_batch_result`` appends to ``data`` every
    response value whose message id is empty, then calls :meth:`compile`.
    """

    # Collected values, in the order they were appended.
    data: list

    def compile(self) -> typing.Any:
        """Reduce the holder to the simplest equivalent value.

        Returns:
            ``None`` when ``data`` is empty, the sole element when ``data``
            holds exactly one item, otherwise this instance itself.
        """
        if not self.data:
            return None

        if len(self.data) == 1:
            return self.data[0]

        return self
162