Passed
Push — master ( 9a8b07...1de29a )
by Michael
03:47
created

BaseJsonRpcClient.__aenter__()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 3
dl 0
loc 3
rs 10
c 0
b 0
f 0
cc 1
nop 1
1
import abc
2
import types
3
import typing
4
import uuid
5
from dataclasses import dataclass
6
from functools import partial
7
8
from .. import constants, errors, utils
9
from ..protocol import JsonRpcRequest, JsonRpcResponse
10
11
12
# Names exported by a star-import of this module.
__all__ = (
    'BaseJsonRpcClient',
    'UnlinkedResults',
)
16
17
18
class BaseJsonRpcClient(abc.ABC):
    """Abstract base class for asynchronous JSON-RPC clients.

    Subclasses provide the transport by implementing ``connect``,
    ``disconnect`` and ``send_json``.  On top of those, this class builds
    requests, parses responses and offers the ``call`` / ``notify`` /
    ``batch`` / ``batch_notify`` helpers.

    Instances are asynchronous context managers: entering awaits
    ``connect()`` and leaving awaits ``disconnect()``.
    """

    # Error code -> known error; passed to ``JsonRpcResponse.from_dict`` when
    # responses are parsed.
    error_map: typing.Dict[int, errors.JsonRpcError] = {
        error.code: error
        for error in errors.DEFAULT_KNOWN_ERRORS
    }
    # Serializer hook; not referenced inside this class (presumably used by
    # transport subclasses -- confirm against them).
    json_serialize: typing.Callable = utils.json_serialize

    async def __aenter__(self) -> 'BaseJsonRpcClient':
        """Connect and return the client itself."""
        await self.connect()
        return self

    async def __aexit__(self,
                        exc_type: typing.Optional[typing.Type[BaseException]],
                        exc_value: typing.Optional[BaseException],
                        traceback: typing.Optional[types.TracebackType]) -> None:
        """Disconnect; returns ``None``, so an in-flight exception propagates."""
        await self.disconnect()

    @abc.abstractmethod
    async def connect(self) -> None:
        """Open the underlying transport (implemented by subclasses)."""
        pass

    @abc.abstractmethod
    async def disconnect(self) -> None:
        """Close the underlying transport (implemented by subclasses)."""
        pass

    async def call(self, method: str, *args, **kwargs) -> typing.Any:
        """Call a remote method and return its result.

        A fresh UUID4 string is used as the message id.

        Raises:
            The error object carried by the response, when it has one.
        """
        request = JsonRpcRequest(msg_id=str(uuid.uuid4()), method=method, args=args, kwargs=kwargs)
        response = await self.direct_call(request)

        if response.error:
            raise response.error

        return response.result

    async def notify(self, method: str, *args, **kwargs) -> None:
        """Send a request built without a message id and do not read a response."""
        request = JsonRpcRequest(method=method, args=args, kwargs=kwargs)
        await self.send_json(request.to_dict(), without_response=True)

    async def batch(self, methods: typing.Iterable[typing.Union[str, list, tuple]]) -> typing.Any:
        """Send several calls in one batch and return their outcomes in order.

        Each item of ``methods`` is either a method name or a list/tuple of
        ``(method,)``, ``(method, params)`` or ``(method, args, kwargs)``.

        Returns:
            A list with one entry per request, in request order.  An entry is
            the matching response's error if it has one, otherwise its
            result -- errors are returned, not raised.  Requests without a
            matching response receive the compiled "unlinked" results: the
            outcomes of all id-less responses (``None`` if there were none,
            the single outcome if there was one, else an ``UnlinkedResults``).
        """
        requests = [self._parse_batch_method(method) for method in methods]
        responses = await self.direct_batch(requests)
        unlinked = UnlinkedResults(data=[])
        outcome_by_id = {}

        for response in responses:
            outcome = response.error or response.result

            if response.msg_id is None or response.msg_id is constants.NOTHING:
                # No id to correlate with a request: keep it aside.
                unlinked.data.append(outcome)
                continue

            outcome_by_id[response.msg_id] = outcome

        unlinked_results = unlinked.compile()
        result = []

        for request in requests:
            if request.msg_id is None or request.msg_id is constants.NOTHING:
                result.append(unlinked_results)
                continue

            result.append(outcome_by_id.get(request.msg_id, unlinked_results))

        return result

    async def batch_notify(self, methods: typing.Iterable[typing.Union[str, list, tuple]]) -> None:
        """Send several notifications in one batch without reading a response.

        ``methods`` uses the same item formats as ``batch``.
        """
        requests = [self._parse_batch_method(method, is_notification=True) for method in methods]
        data = [request.to_dict() for request in requests]
        await self.send_json(data, without_response=True)

    async def direct_call(self, request: JsonRpcRequest) -> JsonRpcResponse:
        """Send a prepared request and return the parsed response.

        Unlike ``call``, an error in the response is not raised here.
        """
        json_response, context = await self.send_json(request.to_dict())
        response = JsonRpcResponse.from_dict(
            json_response,
            error_map=self.error_map,
            context=context,
        )
        return response

    async def direct_batch(self, requests: typing.List[JsonRpcRequest]) -> typing.List[JsonRpcResponse]:
        """Send prepared requests as one batch and return the parsed responses.

        The responses are returned in the order the server sent them; no
        matching against ``requests`` is done here.
        """
        data = [request.to_dict() for request in requests]
        json_response, context = await self.send_json(data)

        # JSON-RPC 2.0 lets a server answer an invalid batch with a single
        # Response object instead of an array.  Iterating that dict would
        # feed its keys to ``from_dict``, so treat it as a one-item batch.
        if isinstance(json_response, dict):
            json_response = [json_response]

        return [
            JsonRpcResponse.from_dict(item, error_map=self.error_map, context=context)
            for item in json_response
        ]

    @abc.abstractmethod
    async def send_json(self,
                        data: typing.Any, *,
                        without_response: bool = False) -> typing.Tuple[typing.Any, typing.Optional[dict]]:
        """Transmit ``data`` over the transport (implemented by subclasses).

        Callers in this class unpack the return value as
        ``(json_response, context)`` and pass ``without_response=True`` for
        notifications, whose return value they ignore.
        """
        pass

    def __getattr__(self, method) -> typing.Callable:
        """Expose remote methods as attributes: ``client.foo(1)`` is ``call('foo', 1)``.

        Python invokes this only when normal attribute lookup fails.

        Raises:
            AttributeError: for dunder names, so protocol probes such as
                ``getattr(client, '__deepcopy__', None)`` are not mistaken
                for remote methods.  Such methods remain reachable through
                ``call``.
        """
        if method.startswith('__') and method.endswith('__'):
            raise AttributeError(
                f'{type(self).__name__!r} object has no attribute {method!r}'
            )

        return partial(self.call, method)

    @staticmethod
    def _parse_batch_method(batch_method: typing.Union[str, list, tuple], *,
                            is_notification: bool = False) -> JsonRpcRequest:
        """Build a request from one item of a batch specification.

        Accepted forms: ``'method'``, ``[method]``, ``[method, params]`` and
        ``[method, args, kwargs]``.  Notifications get ``constants.NOTHING``
        as message id; other requests get a fresh UUID4 string.

        Raises:
            errors.InvalidParams: if a non-string item has a length other
                than 1, 2 or 3.
        """
        msg_id = constants.NOTHING if is_notification else str(uuid.uuid4())

        if isinstance(batch_method, str):
            return JsonRpcRequest(msg_id=msg_id, method=batch_method)

        size = len(batch_method)

        if size == 1:
            return JsonRpcRequest(msg_id=msg_id, method=batch_method[0])

        if size == 2:
            return JsonRpcRequest(msg_id=msg_id, method=batch_method[0], params=batch_method[1])

        if size == 3:
            return JsonRpcRequest(msg_id=msg_id, method=batch_method[0], args=batch_method[1], kwargs=batch_method[2])

        raise errors.InvalidParams('Use string or list (length less than or equal to 3).')
131
132
133
@dataclass
class UnlinkedResults:
    """Container for batch outcomes that could not be tied to a request id."""

    # Collected outcomes (results or errors), in the order they were appended.
    data: list

    def compile(self) -> typing.Any:
        """Collapse the container to its simplest equivalent value.

        Returns:
            ``None`` when ``data`` is empty, the sole element when ``data``
            holds exactly one item, otherwise this object itself.
        """
        items = self.data

        if items:
            # One item is returned bare; several stay wrapped in this object.
            return items[0] if len(items) == 1 else self

        return None
145