Total Complexity | 20 |
Total Lines | 362 |
Duplicated Lines | 0 % |
Changes | 0 |
1 | # -*- coding: utf-8 -*- |
||
|
|||
2 | # MIT License |
||
3 | # |
||
4 | # Copyright (c) 2021 Pincer |
||
5 | # |
||
6 | # Permission is hereby granted, free of charge, to any person obtaining |
||
7 | # a copy of this software and associated documentation files |
||
8 | # (the "Software"), to deal in the Software without restriction, |
||
9 | # including without limitation the rights to use, copy, modify, merge, |
||
10 | # publish, distribute, sublicense, and/or sell copies of the Software, |
||
11 | # and to permit persons to whom the Software is furnished to do so, |
||
12 | # subject to the following conditions: |
||
13 | # |
||
14 | # The above copyright notice and this permission notice shall be |
||
15 | # included in all copies or substantial portions of the Software. |
||
16 | # |
||
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, |
||
18 | # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF |
||
19 | # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. |
||
20 | # IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY |
||
21 | # CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, |
||
22 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE |
||
23 | # SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. |
||
24 | |||
25 | from __future__ import annotations |
||
26 | |||
27 | import logging |
||
28 | from asyncio import iscoroutinefunction |
||
29 | from typing import Optional, Any, Union, Dict, Tuple, List |
||
30 | |||
31 | from pincer import __package__ |
||
1 ignored issue
–
show
|
|||
32 | from pincer._config import events |
||
33 | from pincer.core.dispatch import GatewayDispatch |
||
34 | from pincer.core.gateway import Dispatcher |
||
35 | from pincer.core.http import HTTPClient |
||
36 | from pincer.exceptions import InvalidEventName |
||
37 | from pincer.objects import User |
||
38 | from pincer.utils.extraction import get_index |
||
39 | from pincer.utils.insertion import should_pass_cls |
||
40 | from pincer.utils.types import Coro |
||
41 | |||
_log = logging.getLogger(__package__)

# Type alias used to annotate middleware lookups in
# ``Client.handle_middleware``.
middleware_type = Optional[
    Union[Coro, Tuple[str, List[Any], Dict[str, Any]]]
]

# Registry shared by the ``middleware`` decorator and ``Client.event``,
# seeded from ``pincer._config.events``. For every event name ``x`` it holds:
#
#   _events["x"]     The middleware for ``x``. It starts as the plain string
#                    "on_x", which means "no middleware: forward directly to
#                    the final ``on_x`` event". The ``middleware`` decorator
#                    replaces it with a coroutine wrapper. That coroutine must
#                    return a tuple of (next key, args list, kwargs dict); the
#                    last two are passed on as *args and **kwargs.
#   _events["on_x"]  The final event coroutine registered by the client via
#                    ``Client.event``. It starts as ``None``.
#                    Do not overwrite it manually.
_events: Dict[str, Optional[Union[str, Coro]]] = {}

for event in events:
    event_final_executor = f"on_{event}"

    # Default middleware: a key that points straight at the final event.
    _events[event] = event_final_executor

    # Placeholder for the user's event handler.
    _events[event_final_executor] = None
||
66 | |||
67 | |||
def middleware(call: str, *, override: bool = False):
    r"""
    Middleware are methods which can be registered with this decorator.
    These methods are invoked before any ``on_`` event, as the ``on_``
    event is always the final call.

    A default call exists for all events, but some might already be in
    use by the library.

    If you know what you are doing, you can override these default
    middleware methods by passing the override parameter.

    The method to which this decorator is registered must be a coroutine,
    and it must return a tuple with the following format\:

    .. code-block:: python

        tuple(
            key for next middleware or final event [str],
            args for next middleware/event which will be passed as *args
                [list(Any)],
            kwargs for next middleware/event which will be passed as
                **kwargs [dict(Any)]
        )

    The first parameter passed to the middleware is the payload, which is
    of type :class:`~.core.dispatch.GatewayDispatch` and contains the
    response from the discord API. When the middleware is reached through
    another middleware, the args and kwargs returned by that previous
    middleware are passed on after the payload. If the decorated method
    takes ``self``/``cls`` as its first parameter, the client instance is
    passed in front of the payload.

    :Implementation example:

    .. code-block:: pycon

        >>> @middleware("ready", override=True)
        >>> async def custom_ready(payload: GatewayDispatch):
        >>>     return "on_ready", [User.from_dict(payload.data.get("user"))]

        >>> @Client.event
        >>> async def on_ready(bot: User):
        >>>     print(f"Signed in as {bot}")


    :param call:
        The call where the method should be registered.

    Keyword Arguments:

    :param override:
        Setting this to True will allow you to override existing
        middleware. Usage of this is discouraged, but can help you out
        of some situations.

    :raises RuntimeError:
        Raised by the returned decorator when a middleware for ``call`` has
        already been registered and ``override`` is not set.
    """

    def decorator(func: Coro):
        if override:
            _log.warning(
                f"Middleware overriding has been enabled for `{call}`."
                " This might cause unexpected behaviour."
            )

        # Default entries in ``_events`` are plain strings (not callable),
        # so only a previously registered middleware triggers this error.
        if not override and callable(_events.get(call)):
            raise RuntimeError(
                f"Middleware event with call `{call}` has "
                "already been registered"
            )

        async def wrapper(cls, payload: GatewayDispatch, *args, **kwargs):
            # ``*args``/``**kwargs`` are forwarded by
            # ``Client.handle_middleware`` when this middleware is reached
            # from a previous one.
            _log.debug("`%s` middleware has been invoked", call)

            if should_pass_cls(func):
                pending = func(cls, payload, *args, **kwargs)
            else:
                pending = func(payload, *args, **kwargs)

            # Await exactly once: the coroutine's result is the middleware's
            # tuple, which is not itself awaitable.
            return await pending

        _events[call] = wrapper
        return wrapper

    return decorator
||
147 | |||
148 | |||
class Client(Dispatcher):
    def __init__(self, token: str):
        r"""
        The client is the main instance which is between the programmer
        and the discord API.

        This client represents your bot.

        :param token:
            The secret bot token which can be found in
            `<https://discord.com/developers/applications/\<bot_id\>/bot>`_
        """
        # TODO: Implement intents
        super().__init__(
            token,
            handlers={
                # Use this event handler for opcode 0.
                0: self.event_handler
            }
        )

        # Populated by ``on_ready_middleware`` when the ready event arrives.
        self.bot: Optional[User] = None
        self.__token = token

    @property
    def http(self):
        """
        Returns a http client with the current client its
        authentication credentials. A new ``HTTPClient`` is created on
        every access.

        :Usage example:

        .. code-block:: pycon

            >>> async with self.http as client:
            >>>     await client.post(
            ...         '<endpoint>',
            ...         {
            ...             "foo": "bar",
            ...             "bar": "baz",
            ...             "baz": "foo"
            ...         }
            ...    )

        """
        return HTTPClient(self.__token)

    @staticmethod
    def event(coroutine: Coro):
        """
        Register a Discord gateway event listener. This event will get
        called when the client receives a new event update from Discord
        which matches the event name.

        The event name gets pulled from your method name, and this must
        start with ``on_``. This forces you to write clean and consistent
        code.

        This decorator can be used in and out of a class, and all
        event methods must be coroutines. *(async)*

        :Example usage:

        .. code-block:: pycon

            >>> # Function based
            >>> from pincer import Client
            >>>
            >>> client = Client("token")
            >>>
            >>> @client.event
            >>> async def on_ready():
            ...     print(f"Signed in as {client.bot}")
            >>>
            >>> if __name__ == "__main__":
            ...     client.run()

        .. code-block :: pycon

            >>> # Class based
            >>> from pincer import Client
            >>>
            >>> class BotClient(Client):
            ...     @Client.event
            ...     async def on_ready(self):
            ...         print(f"Signed in as {self.bot}")
            >>>
            >>> if __name__ == "__main__":
            ...     BotClient("token").run()


        :param coroutine:
            The coroutine function to register. Its lower-cased
            ``__name__`` is used as the key in ``_events``. Names that are
            not listed in the event configuration are accepted as well, so
            a custom middleware can forward to them.

        :return:
            The same coroutine, unchanged, so the decorator is transparent.

        :raises TypeError:
            If the method is not a coroutine.

        :raises InvalidEventName:
            If the event name does not start with ``on_``, or a listener
            for that name has already been registered.
        """

        if not iscoroutinefunction(coroutine):
            raise TypeError(
                "Any event which is registered must be a coroutine function"
            )

        name: str = coroutine.__name__.lower()

        if not name.startswith("on_"):
            raise InvalidEventName(
                f"The event `{name}` its name must start with `on_`"
            )

        # ``on_`` slots are seeded with ``None``; any other value means a
        # listener already occupies this name.
        if _events.get(name) is not None:
            raise InvalidEventName(
                f"The event `{name}` has already been registered or is not "
                "a valid event name."
            )

        _events[name] = coroutine
        return coroutine

    async def handle_middleware(
        self,
        payload: GatewayDispatch,
        key: str,
        *args,
        **kwargs
    ) -> Tuple[str, List[Any], Dict[str, Any]]:
        r"""
        Handles all middleware recursively. Stops when it has found an
        event name which starts with ``on_``.

        :param payload:
            The original payload for the event.

        :param key:
            The index of the middleware in ``_events``.

        :param \*args:
            The arguments which will be passed to the middleware.

        :param \*\*kwargs:
            The named arguments which will be passed to the middleware.

        :raises RuntimeError:
            If no middleware is registered for ``key``, a middleware does
            not return a tuple, or the next key it provides is not a
            non-empty string.

        :return:
            A tuple where the first element is the final executor
            (so the event) its index in ``_events``. The second and third
            element are the ``*args`` and ``**kwargs`` for the event.
        """
        ware: middleware_type = _events.get(key)

        if ware is None:
            raise RuntimeError(f"Middleware `{key}` has not been registered.")

        # A plain string entry forwards directly to that key with no args.
        next_call, arguments, params = ware, [], {}

        if iscoroutinefunction(ware):
            extractable = await ware(self, payload, *args, **kwargs)

            if not isinstance(extractable, tuple):
                raise RuntimeError(
                    f"Return type from `{key}` middleware must be tuple. "
                )

            # ``get_index`` presumably yields the fallback when the tuple is
            # shorter than the requested index.
            next_call = get_index(extractable, 0, "")
            arguments = get_index(extractable, 1, [])
            params = get_index(extractable, 2, {})

        # Validate before ``startswith`` so a bad return value produces a
        # clear error instead of an AttributeError or a lookup of "".
        if not isinstance(next_call, str) or not next_call:
            raise RuntimeError(
                f"Middleware `{key}` must provide the next event key as a "
                f"non-empty string, got {next_call!r}."
            )

        if next_call.startswith("on_"):
            return next_call, arguments, params

        return await self.handle_middleware(
            payload, next_call, *arguments, **params
        )

    async def event_handler(self, _, payload: GatewayDispatch):
        """
        Handles all payload events with opcode 0.

        :param _:
            Socket param, but this isn't required for this handler. So
            its just a filler parameter, doesn't matter what is passed.

        :param payload:
            The payload sent from the Discord gateway, this contains the
            required data for the client to know what event it is and
            what specifically happened.
        """
        event_name = payload.event_name.lower()

        key, args, kwargs = await self.handle_middleware(payload, event_name)

        call = _events.get(key)

        # No listener registered for this event (slot still ``None``):
        # the event is intentionally ignored.
        if iscoroutinefunction(call):
            if should_pass_cls(call):
                await call(self, *args, **kwargs)
            else:
                await call(*args, **kwargs)

    @middleware("ready")
    async def on_ready_middleware(self, payload: GatewayDispatch):
        """
        Middleware for ``on_ready`` event.

        Stores the bot user from the payload on ``self.bot`` and forwards
        to the ``on_ready`` event without arguments.

        :param payload: The data received from the ready event.
        """
        self.bot = User.from_dict(payload.data.get("user"))
        return ("on_ready",)


# ``Bot`` is an alias of ``Client``; both names refer to the same class.
Bot = Client
||
362 |