Passed
Push — master ( 7dbe03...d2ae7d )
by Michael
05:22
created

BaseJsonRpcServer._load_middlewares()   A

Complexity

Conditions 2

Size

Total Lines 7
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 6
dl 0
loc 7
rs 10
c 0
b 0
f 0
cc 2
nop 1
1
import abc
2
import asyncio
3
import typing
4
from functools import partial
5
6
from .. import errors, protocol, typedefs, utils
7
8
9
# Names exported by ``from <module> import *``.
__all__ = ('BaseJsonRpcServer',)
12
13
14
class BaseJsonRpcServer(abc.ABC):
    """Transport-independent core of a JSON-RPC server.

    Keeps a registry of RPC methods, composes request middlewares around the
    method dispatcher and turns a decoded JSON payload (one request mapping or
    a batch sequence) into dumped JSON-RPC responses.
    NOTE(review): receiving/decoding the payload and serializing the result
    (presumably with ``json_serialize``) is left to code not visible here.
    """

    # Registered methods, keyed by the name clients use in ``"method"``.
    methods: typing.MutableMapping[str, protocol.BaseJsonRpcMethod]
    # Middlewares, outermost first; each is called as
    # ``middleware(request, handler=next_handler)``.
    middlewares: typing.Sequence[typing.Callable]
    json_serialize: typedefs.UnboundJSONEncoderType
    # Head of the composed middleware chain (built by ``_load_middlewares``).
    _middleware_chain: typedefs.UnboundSingleRequestProcessorType

    def __init__(self, *,
                 json_serialize: typedefs.JSONEncoderType = utils.json_serialize,
                 middlewares: typing.Sequence = (),
                 methods: typing.Optional[typing.Dict[str, protocol.BaseJsonRpcMethod]] = None) -> None:
        """Create a server.

        :param json_serialize: JSON encoder stored as ``self.json_serialize``
            (not used by this class itself).
        :param middlewares: request middlewares, outermost first.
        :param methods: initial method registry. When omitted, the server
            exposes its own ``get_method`` and ``get_methods`` introspection
            methods. A supplied mapping is stored as-is, not copied.
        """
        if methods is None:
            methods = {
                'get_method': protocol.JsonRpcMethod('', self.get_method),
                'get_methods': protocol.JsonRpcMethod('', self.get_methods),
            }

        self.methods = methods

        self.middlewares = middlewares
        self._load_middlewares()

        self.json_serialize = json_serialize  # type: ignore

    def add_method(self,
                   method: typedefs.ServerMethodDescriptionType, *,
                   replace: bool = False) -> protocol.BaseJsonRpcMethod:
        """Register one method and return its ``BaseJsonRpcMethod`` wrapper.

        :param method: a ready ``BaseJsonRpcMethod``, a plain callable
            (wrapped as ``JsonRpcMethod('', method)``), or an iterable that is
            unpacked as positional arguments of ``protocol.JsonRpcMethod``.
        :param replace: allow overwriting a method with the same name.
        :raises errors.InvalidParams: if the name is already registered and
            ``replace`` is false.
        """
        if not isinstance(method, protocol.BaseJsonRpcMethod):
            if callable(method):
                method = protocol.JsonRpcMethod('', method)
            else:
                method = protocol.JsonRpcMethod(*method)

        if not replace and method.name in self.methods:
            raise errors.InvalidParams(f'Method {method.name} has already been added.')

        self.methods[method.name] = method

        return method

    def add_methods(self,
                    methods: typing.Iterable[typedefs.ServerMethodDescriptionType], *,
                    replace: bool = False) -> typing.Tuple[protocol.BaseJsonRpcMethod, ...]:
        """Register several methods with ``add_method``, in order.

        Not atomic: if one entry raises, the entries before it stay registered.
        """
        return tuple(
            self.add_method(method, replace=replace)
            for method in methods
        )

    async def call(self,
                   method_name: str, *,
                   args: typing.Optional[typing.Sequence] = None,
                   kwargs: typing.Optional[typing.Mapping] = None,
                   extra_args: typing.Optional[typing.Mapping] = None) -> typing.Any:
        """Invoke a registered method by name and return its result.

        ``args``/``kwargs`` default to empty; ``extra_args`` is forwarded
        unchanged (possibly ``None``).

        :raises errors.MethodNotFound: if ``method_name`` is not registered.
        """
        if args is None:
            args = ()

        if kwargs is None:
            kwargs = {}

        if method_name not in self.methods:
            raise errors.MethodNotFound

        return await self.methods[method_name](args=args, kwargs=kwargs, extra_args=extra_args)

    @staticmethod
    def _describe_method(method: protocol.BaseJsonRpcMethod) -> typing.Mapping[str, typing.Any]:
        # Introspection payload shared by get_method() and get_methods().
        return {
            'doc': method.func.__doc__,
            'args': method.supported_args,
            'kwargs': method.supported_kwargs,
        }

    def get_methods(self) -> typing.Mapping[str, typing.Mapping[str, typing.Any]]:
        """Describe every registered method (docstring and supported args/kwargs)."""
        return {
            name: self._describe_method(method)
            for name, method in self.methods.items()
        }

    def get_method(self, name: str) -> typing.Optional[typing.Mapping[str, typing.Any]]:
        """Describe the method registered as ``name``, or return ``None`` if absent."""
        method = self.methods.get(name)

        if method is None:
            return None

        return self._describe_method(method)

    def _load_middlewares(self) -> None:
        """Compose ``self.middlewares`` around ``_process_single_request``.

        Wrapping happens from the last middleware to the first, so the first
        middleware in the sequence is the outermost one. Each receives the
        next handler through its ``handler`` keyword argument.
        """
        chain = self._process_single_request

        for middleware in reversed(self.middlewares):
            chain = partial(middleware, handler=chain)

        self._middleware_chain = chain  # type: ignore

    async def _process_input_data(
            self,
            data: typing.Any, *,
            context: typing.MutableMapping[str, typing.Any],
    ) -> typing.Optional[typing.Union[typing.Mapping, typing.Tuple[typing.Mapping, ...]]]:
        """Handle a decoded JSON-RPC payload: a single request or a batch.

        :param data: the decoded JSON body.
        :param context: forwarded to ``JsonRpcRequest.load`` for each request.
        :returns: one dumped response mapping for a single request or an
            invalid payload; a tuple of dumped responses for a batch; ``None``
            when nothing should be sent back (a notification, or a batch made
            only of notifications).
        """
        # str/bytes/bytearray are Sequences too, but never a JSON-RPC batch;
        # route them to the "invalid request" response below instead of
        # treating every character as a separate request.
        is_batch = (
            isinstance(data, typing.Sequence)
            and not isinstance(data, (str, bytes, bytearray))
        )

        if is_batch:
            if not data:
                # An empty batch is itself an invalid request.
                return protocol.JsonRpcResponse(error=errors.InvalidRequest()).dump()

            json_responses = await asyncio.gather(
                *(
                    self._process_single_json_request(raw_rpc_request, context=context)
                    for raw_rpc_request in data
                ),
                return_exceptions=True,
            )

            # Notifications yield None and are omitted from the batch reply.
            result = tuple(
                json_response
                for json_response in self._prepare_exceptions(json_responses)
                if json_response is not None
            )

            return result or None

        if isinstance(data, typing.Mapping):
            return await self._process_single_json_request(data, context=context)

        response = protocol.JsonRpcResponse(error=errors.InvalidRequest('Data must be a dict or an list.'))
        return response.dump()

    @staticmethod
    def _prepare_exceptions(values: typing.Iterable) -> typing.Iterable:
        """Post-process results of ``asyncio.gather(..., return_exceptions=True)``.

        A ``JsonRpcError`` becomes a dumped error response, the same shape as
        every other batch element. Any other exception is re-raised, including
        ``asyncio.CancelledError``, which is a ``BaseException`` (not an
        ``Exception``) on Python 3.8+. Ordinary values pass through unchanged.
        """
        for value in values:
            if isinstance(value, errors.JsonRpcError):
                # NOTE(review): the originating request id is not available
                # here, so the error response carries no id.
                yield protocol.JsonRpcResponse(error=value).dump()
            elif isinstance(value, BaseException):
                raise value
            else:
                yield value

    async def _process_single_json_request(self,
                                           json_request: typing.Any, *,
                                           context: typing.MutableMapping[str, typing.Any],
                                           ) -> typing.Optional[typing.Mapping]:
        """Parse one raw request, run it through the middleware chain, and dump the reply.

        :returns: the dumped response, or ``None`` for a notification.
            Invalid input and request-parsing errors are returned as dumped
            error responses rather than raised.
        """
        if not isinstance(json_request, typing.Mapping):
            return protocol.JsonRpcResponse(error=errors.InvalidRequest('Data must be a dict.')).dump()

        try:
            request = protocol.JsonRpcRequest.load(json_request, context=context)
        except errors.JsonRpcError as e:
            return protocol.JsonRpcResponse(id=json_request.get('id'), error=e).dump()

        response = await self._middleware_chain(request)

        if response.is_notification:
            return None

        return response.dump()

    async def _process_single_request(self, request: protocol.JsonRpcRequest) -> protocol.JsonRpcResponse:
        """Innermost handler of the middleware chain: call the target method.

        A ``JsonRpcError`` raised by the call is converted into the response's
        ``error``; any other exception propagates.
        """
        result, error = None, None

        try:
            result = await self.call(
                request.method_name,
                args=request.args,
                kwargs=request.kwargs,
                extra_args=request.extra_args,
            )
        except errors.JsonRpcError as e:
            error = e

        return protocol.JsonRpcResponse(
            id=request.id,
            jsonrpc=request.jsonrpc,
            result=result,
            error=error,
        )
189