Passed
Push — main ( 6627c3...a229ba )
by
unknown
01:35
created

pincer.client.Client.get_guild()   A

Complexity

Conditions 1

Size

Total Lines 12
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 2
dl 0
loc 12
rs 10
c 0
b 0
f 0
cc 1
nop 2
1
# Copyright Pincer 2021-Present
2
# Full MIT License can be found in `LICENSE` at the project root.
3
4
from __future__ import annotations
5
6
import logging
7
from asyncio import iscoroutinefunction, run
8
from inspect import isasyncgenfunction
9
from typing import Optional, Any, Union, Dict, Tuple, List
10
11
from . import __package__
12
from ._config import events
13
from .commands import ChatCommandHandler
14
from .core.dispatch import GatewayDispatch
15
from .core.gateway import Dispatcher
16
from .core.http import HTTPClient
17
from .exceptions import InvalidEventName
18
from .middleware import middleware
19
from .objects import User, Intents, Guild, ThrottleInterface
20
from .objects.throttling import DefaultThrottleHandler
21
from .utils import get_index, should_pass_cls, Coro
22
23
# Module-wide logger, named after the package.
_log = logging.getLogger(__package__)

# What an ``_events`` middleware slot may hold or yield: nothing, a
# coroutine, or a ``(key, args, kwargs)`` tuple.
MiddlewareType = Optional[Union[Coro, Tuple[str, List[Any], Dict[str, Any]]]]

# Registry shared by the middleware decorator and the client. Keys are
# either a bare event name or its ``on_`` prefixed counterpart.
_events: Dict[str, Optional[Union[str, Coro]]] = {}

for event in events:
    event_final_executor = f"on_{event}"

    _events.update({
        # Middleware slot for the event. Until a middleware coroutine is
        # registered for it, it simply points at the final ``on_`` key.
        # A registered middleware receives the payload (GatewayDispatch)
        # and must return a tuple: the next key as a string, then a list
        # of positional arguments and a dictionary of keyword arguments
        # which are passed on as *args and **kwargs.
        event: event_final_executor,
        # Listener slot which is filled in by the client when an event is
        # registered. Do not manually overwrite.
        event_final_executor: None,
    })
47
48
49
def event_middleware(call: str, *, override: bool = False):
    """
    Middleware are methods which can be registered with this decorator.
    These methods are invoked before any ``on_`` event.
    As the ``on_`` event is the final call.

    A default call exists for all events, but some might already be in
    use by the library.

    If you know what you are doing, you can override these default
    middleware methods by passing the override parameter.

    The method to which this decorator is registered must be a coroutine,
    and it must return a tuple with the following format:

    .. code-block:: python

        tuple(
            key for next middleware or final event [str],
            args for next middleware/event which will be passed as *args
                [list(Any)],
            kwargs for next middleware/event which will be passed as
                **kwargs [dict(Any)]
        )

    Two parameters are passed to the middleware. The first parameter is
    the object the middleware is invoked with (only passed when
    ``should_pass_cls`` reports that the method wants it) and the second
    one is the payload parameter which is of type
    :class:`~.core.dispatch.GatewayDispatch`.
    This contains the response from the discord API.

    Any extra positional and keyword arguments handed over by a previous
    middleware in the chain are forwarded to the method after the payload.

    :Implementation example:

    .. code-block:: pycon

        >>> @event_middleware("ready", override=True)
        >>> async def custom_ready(_, payload: GatewayDispatch):
        >>>     return "on_ready", [
        >>>         User.from_dict(payload.data.get("user"))
        >>>     ]

        >>> @Client.event
        >>> async def on_ready(bot: User):
        >>>     print(f"Signed in as {bot}")


    :param call:
        The call where the method should be registered.

    Keyword Arguments:

    :param override:
        Setting this to True will allow you to override existing
        middleware. Usage of this is discouraged, but can help you out
        of some situations.

    :raises RuntimeError:
        When a middleware has already been registered for ``call`` and
        ``override`` is not enabled.
    """

    def decorator(func: Coro):
        if override:
            _log.warning(
                "Middleware overriding has been enabled for `%s`."
                " This might cause unexpected behavior.", call
            )

        # A callable entry means a middleware already occupies this slot;
        # unregistered slots hold a plain string or nothing at all.
        if not override and callable(_events.get(call)):
            raise RuntimeError(
                f"Middleware event with call `{call}` has "
                "already been registered"
            )

        async def wrapper(cls, payload: GatewayDispatch, *args, **kwargs):
            _log.debug("`%s` middleware has been invoked", call)

            # The extra arguments come from the previous middleware in the
            # chain; they must be accepted here, otherwise chaining a
            # middleware with arguments fails with a TypeError.
            return await (
                func(cls, payload, *args, **kwargs)
                if should_pass_cls(func)
                else func(payload, *args, **kwargs)
            )

        _events[call] = wrapper
        return wrapper

    return decorator
131
132
133
# Register every middleware shipped with the library. Private loop names
# are used so that the imported ``middleware`` mapping and the module
# level ``event`` name are not rebound by the loop.
for _event_name, _middleware_func in middleware.items():
    event_middleware(_event_name)(_middleware_func)
135
136
137
class Client(Dispatcher):
    """
    The main instance which sits between the programmer and the
    discord API. It routes received gateway payloads through the
    registered middleware to the registered ``on_`` events.
    """

    def __init__(
            self,
            token: str, *,
            received: Optional[str] = None,
            intents: Optional[Intents] = None,
            throttler: ThrottleInterface = DefaultThrottleHandler
    ):
        """
        The client is the main instance which is between the programmer
        and the discord API.

        This client represents your bot.

        :param token:
            The secret bot token which can be found in
            `<https://discord.com/developers/applications/<bot_id>/bot>`_

        :param received:
            The default message which will be sent when no response is
            given. Falls back to ``"Command arrived successfully!"`` when
            omitted or empty.

        :param intents:
            The discord intents for your client. Falls back to
            ``Intents.NONE`` when omitted.

        :param throttler:
            The throttle handler which is stored on the client as
            ``throttler``. Defaults to ``DefaultThrottleHandler``.
        """
        super().__init__(
            token,
            handlers={
                # Gets triggered on all events
                -1: self.payload_event_handler,
                # Use this event handler for opcode 0.
                0: self.event_handler
            },
            intents=intents or Intents.NONE
        )

        # Stays None until something assigns the bot user.
        self.bot: Optional[User] = None
        self.received_message = received or "Command arrived successfully!"
        self.http = HTTPClient(token)
        self.throttler = throttler

    @property
    def chat_commands(self):
        """
        Get a list with the application names of the chat commands which
        have been registered in the ChatCommandHandler.
        """
        return [cmd.app.name for cmd in ChatCommandHandler.register.values()]

    @staticmethod
    def event(coroutine: Coro):
        """
        Register a Discord gateway event listener. This event will get
        called when the client receives a new event update from Discord
        which matches the event name.

        The event name gets pulled from your method name, and this must
        start with ``on_``.
        This forces you to write clean and consistent code.

        This decorator can be used in and out of a class, and all
        event methods must be coroutines. *(async)*

        :Example usage:

        .. code-block:: pycon

            >>> # Function based
            >>> from pincer import Client
            >>>
            >>> client = Client("token")
            >>>
            >>> @client.event
            >>> async def on_ready():
            ...     print(f"Signed in as {client.bot}")
            >>>
            >>> if __name__ == "__main__":
            ...     client.run()

        .. code-block :: pycon

            >>> # Class based
            >>> from pincer import Client
            >>>
            >>> class BotClient(Client):
            ...     @Client.event
            ...     async def on_ready(self):
            ...         print(f"Signed in as {self.bot}")
            >>>
            >>> if __name__ == "__main__":
            ...     BotClient("token").run()


        :param coroutine:
            The coroutine function (or async generator function) which
            should be registered. Its lowercased ``__name__`` is used as
            the event name.

        :returns:
            The coroutine which was passed in, unchanged.

        :raises TypeError:
            If the method is not a coroutine.

        :raises InvalidEventName:
            If the event name does not start with ``on_`` or when
            something has already been registered under that name.
        """

        if not iscoroutinefunction(coroutine) \
                and not isasyncgenfunction(coroutine):
            raise TypeError(
                "Any event which is registered must be a coroutine function"
            )

        name: str = coroutine.__name__.lower()

        if not name.startswith("on_"):
            raise InvalidEventName(
                f"The event named `{name}` must start with `on_`"
            )

        # Only names which are unknown or still hold None are free.
        if _events.get(name) is not None:
            raise InvalidEventName(
                f"The event `{name}` has already been registered or is not "
                "a valid event name."
            )

        _events[name] = coroutine
        return coroutine

    @staticmethod
    def get_event_coro(name: str) -> Optional[Coro]:
        """
        Look up the coroutine which is registered for an event.

        :param name:
            The name under which the event is registered. Surrounding
            whitespace is stripped and the name is lowercased before the
            lookup.

        :returns:
            The registered coroutine function or async generator
            function, or None when the name holds anything else.
        """
        call = _events.get(name.strip().lower())

        if iscoroutinefunction(call) or isasyncgenfunction(call):
            return call

        return None

    def run(self):
        """Start the event listener and close the HTTP client afterwards."""
        try:
            self.start_loop()
        finally:
            # Close the HTTP client even when the loop exits through an
            # exception, so it is never left open.
            run(self.http.close())

    async def handle_middleware(
            self,
            payload: GatewayDispatch,
            key: str,
            *args,
            **kwargs
    ) -> Tuple[str, List[Any], Dict[str, Any]]:
        """
        Handles all middleware recursively. Stops when it has found an
        event name which starts with ``on_``.

        :param payload:
            The original payload for the event.

        :param key:
            The index of the middleware in ``_events``.

        :param \\*args:
            The arguments which will be passed to the middleware.

        :param \\*\\*kwargs:
            The named arguments which will be passed to the middleware.

        :return:
            A tuple where the first element is the final executor
            (so the event) its index in ``_events``.

            The second and third element are the ``*args``
            and ``**kwargs`` for the event.

        :raises RuntimeError:
            If a middleware does not return a tuple, or when nothing has
            been registered for ``key``.
        """
        ware: MiddlewareType = _events.get(key)
        next_call, arguments, params = ware, [], {}

        if iscoroutinefunction(ware):
            extractable = await ware(self, payload, *args, **kwargs)

            if not isinstance(extractable, tuple):
                raise RuntimeError(
                    f"Return type from `{key}` middleware must be tuple. "
                )

            # Missing tuple positions fall back to empty defaults.
            next_call = get_index(extractable, 0, "")
            arguments = get_index(extractable, 1, [])
            params = get_index(extractable, 2, {})

        if next_call is None:
            raise RuntimeError(f"Middleware `{key}` has not been registered.")

        # An ``on_`` key is the final event; every other key is treated
        # as the next middleware in the chain.
        return (
            (next_call, arguments, params)
            if next_call.startswith("on_")
            else await self.handle_middleware(
                payload, next_call, *arguments, **params
            )
        )

    async def execute_error(
            self,
            error: Exception,
            name: str = "on_error",
            *args,
            **kwargs
    ):
        """
        Raises an error if no appropriate error event has been found.

        :param error:
            The error which should be raised or passed to the event.

        :param name:
            The name of the event, and how it is registered by the client.

        :param \\*args:
            The arguments for the event.

        :param \\*\\*kwargs:
            The named arguments for the event.
        """
        # NOTE: the assignment expression below requires Python 3.8+;
        # analysers which parse with an older grammar report it as
        # invalid syntax.
        if call := self.get_event_coro(name):
            await self.execute_event(call, error, *args, **kwargs)
        else:
            raise error

    async def execute_event(self, call: Coro, *args, **kwargs):
        """
        Invokes an event.

        :param call:
            The call (method) to which the event is registered.

        :param \\*args:
            The arguments for the event.

        :param \\*\\*kwargs:
            The named arguments for the event.
        """

        # The client itself is only passed when ``should_pass_cls``
        # reports that the call wants it.
        if should_pass_cls(call):
            await call(self, *args, **kwargs)
        else:
            await call(*args, **kwargs)

    async def process_event(self, name: str, payload: GatewayDispatch):
        """
        Processes and invokes an event and its middleware.

        Any exception which is raised while doing so is handed over to
        :meth:`execute_error`.

        :param name:
            The name of the event, this is also the filename in the
            middleware directory.

        :param payload:
            The payload sent from the Discord gateway, this contains the
            required data for the client to know what event it is and
            what specifically happened.
        """
        try:
            key, args, kwargs = await self.handle_middleware(payload, name)

            # Events without a registered listener are silently skipped.
            if call := self.get_event_coro(key):
                await self.execute_event(call, *args, **kwargs)

        except Exception as e:
            await self.execute_error(e)

    async def event_handler(self, _, payload: GatewayDispatch):
        """
        Handles all payload events with opcode 0.

        :param _:
            Socket param, but this isn't required for this handler. So
            its just a filler parameter, doesn't matter what is passed.

        :param payload:
            The payload sent from the Discord gateway, this contains the
            required data for the client to know what event it is and
            what specifically happened.
        """
        await self.process_event(payload.event_name.lower(), payload)

    async def payload_event_handler(self, _, payload: GatewayDispatch):
        """
        Special event which activates on_payload event!

        :param _:
            Socket param, but this isn't required for this handler. So
            its just a filler parameter, doesn't matter what is passed.

        :param payload:
            The payload sent from the Discord gateway, this contains the
            required data for the client to know what event it is and
            what specifically happened.
        """
        await self.process_event("payload", payload)

    async def get_guild(self, guild_id: int) -> Guild:
        """
        Fetch a guild object by the guild identifier.

        :param guild_id:
            The id of the guild which should be fetched.

        :returns:
            A Guild object.
        """
        return await Guild.from_id(self, guild_id)

    async def get_user(self, _id: int) -> User:
        """
        Fetch a User from its identifier

        :param _id:
            The id of the user which should be fetched.

        :returns:
            A User object.
        """
        return await User.from_id(self, _id)
452
453
454
# Alias: ``Bot`` refers to the very same class object as ``Client``.
Bot = Client
455