Total Complexity | 19 |
Total Lines | 238 |
Duplicated Lines | 0 % |
Changes | 0 |
1 | # -*- coding: utf-8 -*- |
||
|
|||
2 | # MIT License |
||
3 | # |
||
4 | # Copyright (c) 2021 Pincer |
||
5 | # |
||
6 | # Permission is hereby granted, free of charge, to any person obtaining |
||
7 | # a copy of this software and associated documentation files |
||
8 | # (the "Software"), to deal in the Software without restriction, |
||
9 | # including without limitation the rights to use, copy, modify, merge, |
||
10 | # publish, distribute, sublicense, and/or sell copies of the Software, |
||
11 | # and to permit persons to whom the Software is furnished to do so, |
||
12 | # subject to the following conditions: |
||
13 | # |
||
14 | # The above copyright notice and this permission notice shall be |
||
15 | # included in all copies or substantial portions of the Software. |
||
16 | # |
||
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, |
||
18 | # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF |
||
19 | # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. |
||
20 | # IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY |
||
21 | # CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, |
||
22 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE |
||
23 | # SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. |
||
24 | import logging |
||
25 | from asyncio import iscoroutinefunction |
||
26 | from typing import Optional, Any, Union, Dict, Tuple, List |
||
27 | |||
28 | from pincer import __package__ |
||
1 ignored issue
–
show
|
|||
29 | from pincer._config import GatewayConfig, events |
||
30 | from pincer.core.dispatch import GatewayDispatch |
||
31 | from pincer.core.gateway import Dispatcher |
||
32 | from pincer.core.http import HTTPClient |
||
33 | from pincer.exceptions import InvalidEventName |
||
34 | from pincer.objects.user import User |
||
35 | from pincer.utils.extraction import get_index |
||
36 | from pincer.utils.insertion import should_pass_cls |
||
37 | from pincer.utils.types import Coro |
||
38 | |||
_log = logging.getLogger(__package__)

# Type of a value looked up in ``_events``: ``None`` when nothing is
# registered, a ``str`` naming the next key to dispatch to, or a coroutine
# function (a registered middleware wrapper or a final ``on_`` event).
# The ``Tuple`` member mirrors the (key, args, kwargs) value a middleware
# coroutine returns and is kept for backwards compatibility.
middleware_type = Optional[
    Union[Coro, str, Tuple[str, List[Any], Dict[str, Any]]]
]


def _build_default_events() -> Dict[str, Optional[Union[str, Coro]]]:
    """
    Build the initial middleware/event lookup table.

    For every gateway event name in ``events`` two keys are created:

    - ``<event>``: the default middleware for that event, which is simply
      the string ``"on_<event>"`` (the key of the final event). Middleware
      registered with ``middleware(...)`` replaces it with a coroutine that
      must return a tuple of (next key, args list, kwargs dict); the last
      two are passed on as ``*args`` and ``**kwargs``.
    - ``on_<event>``: the final event registered by the client. It starts
      as ``None``; do not overwrite it manually.

    Building the table inside a function keeps the loop variables out of
    the module namespace.
    """
    table: Dict[str, Optional[Union[str, Coro]]] = {}
    for event_name in events:
        final_executor = f"on_{event_name}"
        table[event_name] = final_executor
        table[final_executor] = None
    return table


_events: Dict[str, Optional[Union[str, Coro]]] = _build_default_events()
||
59 | |||
60 | |||
def middleware(call: str, *, override: bool = False):
    """
    Middleware are methods which can be registered with this decorator.
    These methods are invoked before any `on_` event, which is always the
    final call.

    A default call exists for all events, but some might already be in
    use by the library. If you know what you are doing, you can override
    these default middleware methods by passing the override parameter.

    The method to which this decorator is registered must be a coroutine,
    and it must return a tuple with the following format:
    ```
    tuple(
        key for next middleware or final event [str],
        args for next middleware/event which will be passed as *args
            [list(Any)],
        kwargs for next middleware/event which will be passed as
            **kwargs [dict(Any)]
    )
    ```

    The middleware receives the payload (of type GatewayDispatch), which
    contains the response from the discord API. When `should_pass_cls`
    accepts the function (presumably when it declares a leading
    `self`/`cls` parameter — confirm), the client instance is passed
    before the payload. When this middleware is reached from another
    middleware (a returned key that does not start with `on_`), the args
    and kwargs returned by that previous middleware follow the payload.

    Implementation example:
    ```py
    @middleware("ready", override=True)
    async def custom_ready(payload: GatewayDispatch):
        return "on_ready", [User.from_dict(payload.data.get("user"))]

    @Client.event
    async def on_ready(bot: User):
        print(f"Signed in as {bot}")
    ```

    :param call: The call where the method should be registered.
    :param override: Setting this to True will allow you to override
        existing middleware. Usage of this is discouraged, but can help
        you out of some situations.
    :raises TypeError: (when decorating) If the decorated function is not
        a coroutine function.
    :raises RuntimeError: (when decorating) If a coroutine middleware is
        already registered for `call` and `override` is False.
    :return: A decorator which stores an async wrapper around the function
        in `_events[call]` and returns that wrapper.
    """
    import functools  # local import keeps this block self-contained

    def decorator(func: Coro):
        # Fail at registration time instead of on the first dispatched event.
        if not iscoroutinefunction(func):
            raise TypeError(
                f"Middleware for `{call}` must be a coroutine function."
            )

        if override:
            _log.warning(
                "Middleware overriding has been enabled for `%s`. "
                "This might cause unexpected behaviour.",
                call,
            )
        # Default entries in `_events` are plain strings, so only a
        # coroutine there means a middleware has already claimed `call`.
        elif iscoroutinefunction(_events.get(call)):
            raise RuntimeError(
                f"Middleware event with call `{call}` has already been "
                "registered."
            )

        @functools.wraps(func)
        async def wrapper(cls, payload: GatewayDispatch, *args, **kwargs):
            # `Client.handle_middleware` invokes this as
            # ware(client, payload, *args, **kwargs); forward the extra
            # arguments so chained middleware receive them.
            _log.debug("`%s` middleware has been invoked", call)
            if should_pass_cls(func):
                return await func(cls, payload, *args, **kwargs)
            return await func(payload, *args, **kwargs)

        _events[call] = wrapper
        return wrapper

    return decorator
||
122 | |||
123 | |||
class Client(Dispatcher):
    def __init__(self, token: str):
        """
        The client is the main instance which is between the programmer and
        the discord API. This client represents your bot.

        :param token: The secret bot token which can be found in
            https://discord.com/developers/applications/<bot_id>/bot.
        """
        # TODO: Implement intents
        super().__init__(
            token,
            handlers={
                # Use this event handler for opcode 0.
                0: self.event_handler
            }
        )

        # TODO: close the client after use
        self.http = HTTPClient(token, version=GatewayConfig.version)
        # Filled in by the `ready` middleware below.
        self.bot: Optional[User] = None

    @staticmethod
    def event(coroutine: Coro):
        """
        Register a coroutine function as the final handler of an event.

        The lower-cased function name is used as the key in `_events`, so
        it must start with `on_`.

        Example:
        ```py
        @Client.event
        async def on_ready():
            print("Ready!")
        ```

        :param coroutine: The coroutine function to register.
        :raises TypeError: If `coroutine` is not a coroutine function.
        :raises InvalidEventName: If the name does not start with `on_`,
            or a handler is already registered under that name.
        :return: The same coroutine, so this can be used as a decorator.
        """
        if not iscoroutinefunction(coroutine):
            raise TypeError(
                "Any event which is registered must be a coroutine function"
            )

        name: str = coroutine.__name__.lower()

        if not name.startswith("on_"):
            raise InvalidEventName(
                f"The event `{name}` its name must start with `on_`"
            )

        # NOTE(review): names that are absent from `_events` also yield
        # None here and are therefore accepted, even though the message
        # below mentions invalid names. Custom middleware may route to
        # such keys, so this is left permissive — confirm intent.
        if _events.get(name) is not None:
            raise InvalidEventName(
                f"The event `{name}` has already been registered or is not "
                f"a valid event name."
            )

        _events[name] = coroutine
        return coroutine

    async def handle_middleware(
            self,
            payload: GatewayDispatch,
            key: str,
            *args,
            **kwargs
    ) -> Tuple[str, List[Any], Dict[str, Any]]:
        """
        Handles all middleware recursively. Stops when it has found an
        event name which starts with "on_".

        :param payload: The original payload for the event.
        :param key: The index of the middleware in `_events`.
        :param *args: The arguments which will be passed to the middleware.
        :param **kwargs: The named arguments which will be passed to the
            middleware.

        :raises RuntimeError: If a middleware does not return a tuple, or a
            key in the chain has no entry in `_events`.
        :return: A tuple where the first element is the final executor
            (so the event) its index in `_events`. The second and third
            element are the `*args` and `**kwargs` for the event.
        """
        ware: middleware_type = _events.get(key)
        next_call, arguments, params = ware, [], {}

        # A coroutine entry is a middleware that decides the next key.
        # Any other non-None entry (the default is the string
        # "on_<event>") is itself the next key, with no extra arguments.
        if iscoroutinefunction(ware):
            extractable = await ware(self, payload, *args, **kwargs)

            if not isinstance(extractable, tuple):
                raise RuntimeError(f"Return type from `{key}` middleware must "
                                   f"be tuple. ")

            next_call = get_index(extractable, 0, "")
            arguments = get_index(extractable, 1, [])
            params = get_index(extractable, 2, {})

        if next_call is None:
            raise RuntimeError(f"Middleware `{key}` has not been registered.")

        if next_call.startswith("on_"):
            return next_call, arguments, params

        # Not a final event yet: continue with the next middleware.
        return await self.handle_middleware(payload, next_call,
                                            *arguments, **params)

    async def event_handler(self, _, payload: GatewayDispatch):
        """
        Handles all payload events with opcode 0.

        Runs the middleware chain for the payload's (lower-cased) event
        name, then awaits the final `on_` event if a coroutine is
        registered for it.

        :param _: Unused.
        :param payload: The dispatched gateway payload.
        """
        event_name = payload.event_name.lower()

        key, args, kwargs = await self.handle_middleware(payload, event_name)

        call = _events.get(key)

        # Events without a registered coroutine are silently skipped.
        if iscoroutinefunction(call):
            if should_pass_cls(call):
                # Handlers that want the client receive it as `self`.
                kwargs["self"] = self

            await call(*args, **kwargs)

    @middleware("ready")
    async def on_ready_middleware(self, payload: GatewayDispatch):
        """
        Middleware for the `on_ready` event.

        Stores the user from the ready payload on `self.bot`, then forwards
        to `on_ready` without arguments.
        """
        self.bot = User.from_dict(payload.data.get("user"))
        return "on_ready",


Bot = Client
||
238 |