Passed
Push — master ( 239fba...767d9a )
by Michael
03:05
created

BaseJsonRpcClient._collect_batch_result()   A

Complexity

Conditions 5

Size

Total Lines 25
Code Lines 18

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 18
dl 0
loc 25
rs 9.0333
c 0
b 0
f 0
cc 5
nop 2
1
import abc
2
import types
3
import typing
4
import uuid
5
from dataclasses import dataclass
6
from functools import partial
7
8
from .. import constants, errors, utils
9
from ..protocol import JsonRpcRequest, JsonRpcResponse
10
11
12
# Public API of this module: the abstract client and the container used for
# batch responses that could not be matched to a request id.
__all__ = ('BaseJsonRpcClient', 'UnlinkedResults')
16
17
18
class BaseJsonRpcClient(abc.ABC):
    """Transport-agnostic base class for JSON-RPC clients.

    Subclasses provide the transport by implementing :meth:`connect`,
    :meth:`disconnect` and :meth:`send_json`. This class builds
    ``JsonRpcRequest`` objects, parses replies into ``JsonRpcResponse``
    objects and offers ``call`` / ``notify`` / ``batch`` / ``batch_notify``
    on top. Unknown (non-dunder) attributes are proxied to :meth:`call`, so
    ``await client.foo(1)`` is equivalent to ``await client.call('foo', 1)``.
    """

    # Error code -> entry of errors.DEFAULT_KNOWN_ERRORS; handed to
    # JsonRpcResponse.from_dict when parsing responses.
    error_map: typing.Dict[int, errors.JsonRpcError] = {
        error.code: error
        for error in errors.DEFAULT_KNOWN_ERRORS
    }
    # NOTE(review): if utils.json_serialize is a plain function, reading it as
    # ``self.json_serialize`` yields a bound method (self passed first);
    # confirm how subclasses access it.
    json_serialize: typing.Callable = utils.json_serialize

    async def __aenter__(self) -> 'BaseJsonRpcClient':
        """Connect and return the client itself for ``async with``."""
        await self.connect()
        return self

    async def __aexit__(self,
                        exc_type: typing.Optional[typing.Type[BaseException]],
                        exc_value: typing.Optional[BaseException],
                        traceback: typing.Optional[types.TracebackType]) -> None:
        """Disconnect when leaving ``async with``; exceptions are not suppressed."""
        await self.disconnect()

    @abc.abstractmethod
    async def connect(self) -> None:
        """Open the underlying transport (implemented by subclasses)."""

    @abc.abstractmethod
    async def disconnect(self) -> None:
        """Close the underlying transport (implemented by subclasses)."""

    async def call(self, method: str, *args, **kwargs) -> typing.Any:
        """Invoke ``method`` remotely with a fresh uuid4 id and return its result.

        Raises:
            The response's ``error`` object, when the server reported one.
        """
        request = JsonRpcRequest(msg_id=str(uuid.uuid4()), method=method, args=args, kwargs=kwargs)
        response = await self.direct_call(request)

        if response.error:
            raise response.error

        return response.result

    async def notify(self, method: str, *args, **kwargs) -> None:
        """Send ``method`` as a notification (no id) without waiting for a response."""
        request = JsonRpcRequest(method=method, args=args, kwargs=kwargs)
        await self.send_json(request.to_dict(), without_response=True)

    async def batch(self, methods: typing.Iterable[typing.Union[str, list, tuple]]) -> typing.Any:
        """Send several calls in one batch and return their values in request order.

        Each entry of ``methods`` is parsed by :meth:`_parse_batch_method`.
        Failed calls contribute their error object instead of raising; see
        :meth:`_collect_batch_result` for how values are matched to requests.
        """
        requests = [self._parse_batch_method(method) for method in methods]
        responses = await self.direct_batch(requests)
        return self._collect_batch_result(requests, responses)

    async def batch_notify(self, methods: typing.Iterable[typing.Union[str, list, tuple]]) -> None:
        """Send several notifications in one batch without waiting for a response."""
        requests = [self._parse_batch_method(method, is_notification=True) for method in methods]
        data = [request.to_dict() for request in requests]
        await self.send_json(data, without_response=True)

    async def direct_call(self, request: JsonRpcRequest) -> JsonRpcResponse:
        """Send a prepared request and parse the reply; errors are returned, not raised."""
        json_response, context = await self.send_json(request.to_dict())
        response = JsonRpcResponse.from_dict(
            json_response,
            error_map=self.error_map,
            context=context,
        )
        return response

    async def direct_batch(self, requests: typing.List[JsonRpcRequest]) -> typing.List[JsonRpcResponse]:
        """Send prepared requests as one batch and parse every reply item.

        Returns:
            One ``JsonRpcResponse`` per item the server returned, in the
            server's order (not necessarily the request order).
        """
        data = [request.to_dict() for request in requests]
        json_response, context = await self.send_json(data)

        # JSON-RPC 2.0 section 6: when the batch as a whole is invalid (e.g. an
        # empty array) the server replies with a single Response object rather
        # than an array. Iterating that dict would walk its keys, so wrap it.
        if isinstance(json_response, dict):
            json_response = [json_response]

        return [
            JsonRpcResponse.from_dict(item, error_map=self.error_map, context=context)
            for item in json_response
        ]

    @abc.abstractmethod
    async def send_json(self,
                        data: typing.Any, *,
                        without_response: bool = False) -> typing.Tuple[typing.Any, typing.Optional[dict]]:
        """Transmit ``data`` (a request dict or a list of them).

        Returns:
            ``(json_response, context)``. ``json_response`` is parsed with
            ``JsonRpcResponse.from_dict`` and ``context`` is forwarded to it.
            With ``without_response=True`` the callers here ignore the return
            value (presumably no response is awaited; confirm per transport).
        """

    def __getattr__(self, method: str) -> typing.Callable:
        """Proxy an unknown attribute to :meth:`call`: ``client.foo`` -> ``partial(call, 'foo')``.

        Dunder names are rejected with AttributeError: protocol probes such as
        ``copy.deepcopy`` (``__deepcopy__``) or unpickling (``__setstate__``)
        look such names up with ``getattr`` and must not receive an RPC proxy.
        """
        if method.startswith('__') and method.endswith('__'):
            raise AttributeError(f'{type(self).__name__!r} object has no attribute {method!r}')
        return partial(self.call, method)

    @staticmethod
    def _collect_batch_result(requests: typing.List[JsonRpcRequest], responses: typing.List[JsonRpcResponse]) -> list:
        """Arrange batch response values in the order of ``requests``.

        Each response contributes ``response.error`` if truthy, else
        ``response.result``. Responses whose ``msg_id`` is in
        ``constants.EMPTY_VALUES`` cannot be matched to a request. Their values
        are gathered into an ``UnlinkedResults``, and its compiled form is used
        for every request that has no id or no matching response. The
        compiled form is ``None`` if empty, the sole value if there is one, or
        the ``UnlinkedResults`` itself.
        """
        unlinked = UnlinkedResults(data=[])
        responses_map = {}

        for response in responses:
            value = response.error or response.result

            if response.msg_id in constants.EMPTY_VALUES:
                unlinked.data.append(value)
            else:
                responses_map[response.msg_id] = value

        fallback = unlinked.compile()

        return [
            fallback if request.msg_id in constants.EMPTY_VALUES
            else responses_map.get(request.msg_id, fallback)
            for request in requests
        ]

    @staticmethod
    def _parse_batch_method(batch_method: typing.Union[str, list, tuple], *,
                            is_notification: bool = False) -> JsonRpcRequest:
        """Build a ``JsonRpcRequest`` from one ``batch``/``batch_notify`` entry.

        Accepted shapes:
            ``'method'`` or ``('method',)``    -> no parameters
            ``('method', params)``             -> passed as ``params=``
            ``('method', args, kwargs)``       -> passed as ``args=``/``kwargs=``

        Notifications get ``constants.NOTHING`` as id; calls get a uuid4 string.

        Raises:
            errors.InvalidParams: for sequences of length 0 or greater than 3.
        """
        msg_id = constants.NOTHING if is_notification else str(uuid.uuid4())

        if isinstance(batch_method, str):
            return JsonRpcRequest(msg_id=msg_id, method=batch_method)

        if len(batch_method) == 1:
            return JsonRpcRequest(msg_id=msg_id, method=batch_method[0])

        if len(batch_method) == 2:
            return JsonRpcRequest(msg_id=msg_id, method=batch_method[0], params=batch_method[1])

        if len(batch_method) == 3:
            return JsonRpcRequest(msg_id=msg_id, method=batch_method[0], args=batch_method[1], kwargs=batch_method[2])

        raise errors.InvalidParams('Use string or list (length less than or equal to 3).')
137
138
139
@dataclass
class UnlinkedResults:
    """Batch response values that carried no usable id.

    ``compile`` collapses the collection to the most convenient form:
    nothing collected gives ``None``, exactly one value gives that value,
    and several values give this container itself.
    """

    data: list

    def compile(self) -> typing.Any:
        """Return ``None``, the single collected value, or ``self``."""
        count = len(self.data)
        if count == 0:
            return None
        return self.data[0] if count == 1 else self
151