Completed
Push — main ( 6a4b2b...72993c )
by
unknown
01:58 queued 01:57
created

pincer.core.gateway.Gateway.__del__()   A

Complexity

Conditions 3

Size

Total Lines 6
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 5
dl 0
loc 6
rs 10
c 0
b 0
f 0
cc 3
nop 1
1
# Copyright Pincer 2021-Present
0 ignored issues
show
introduced by
Missing module docstring
Loading history...
2
# Full MIT License can be found in `LICENSE` at the project root.
3
4
5
from __future__ import annotations
6
7
from asyncio import create_task, Task, ensure_future, sleep
8
from dataclasses import dataclass
9
from datetime import datetime
10
from itertools import repeat, count, chain
11
import logging
12
from platform import system
13
from random import random
14
from typing import TYPE_CHECKING, Any, Dict, Callable, Optional
15
from zlib import decompressobj
16
17
from aiohttp import (
18
    ClientSession, WSMsgType, ClientConnectorError, ClientWebSocketResponse
19
)
20
21
from . import __package__
22
from ..utils.api_object import APIObject
23
from .._config import GatewayConfig
24
from ..core.dispatch import GatewayDispatch
25
from ..exceptions import (
26
    InvalidTokenError, GatewayConnectionError, GatewayError, UnhandledException
27
)
28
29
if TYPE_CHECKING:
30
    from ..objects.app.intents import Intents
31
    Handler = Callable[[GatewayDispatch], None]
32
33
_log = logging.getLogger(__package__)
34
35
ZLIB_SUFFIX = b'\x00\x00\xff\xff'
36
inflator = decompressobj()
37
38
39
@dataclass
40
class SessionStartLimit(APIObject):
41
    """Session start limit info returned from the `gateway/bot` endpoint"""
42
    total: int
43
    remaining: int
44
    reset_after: int
45
    max_concurrency: int
46
47
48
@dataclass
49
class GatewayInfo(APIObject):
50
    """Gateway info returned from the `gateway/bot` endpoint"""
51
    url: str
52
    shards: int
53
    session_start_limit: SessionStartLimit
54
55
56
class Gateway:
0 ignored issues
show
best-practice introduced by
Too many instance attributes (18/7)
Loading history...
57
    """The Gateway handles all interactions with the Discord Websocket API.
58
    This also contains the main event loop, and handles the heartbeat.
59
60
    Running the Gateway will create a connection with the
61
    Discord Websocket API on behalf of the provided token.
62
63
    This token must be a bot token.
64
    (Which can be found on
65
    `<https://discord.com/developers/applications/>`_)
66
67
    Parameters
68
    ----------
69
    token : str.
70
        The token for this bot
71
    intents : :class:`~pincer.objects.app.intents.Intents`
72
        The itents to use. More information can be found at
73
        `<https://discord.com/developers/docs/topics/gateway#gateway-intents>`_.
74
    url : str
75
        The gateway url.
76
    shard : int
77
        The ID of the shard to run.
78
    num_shards : int
79
        Number used to route traffic to the current. This should usually be the total
80
        number of shards that will be run. More information at
81
        `<https://discord.com/developers/docs/topics/gateway#sharding>`_.
82
    """
83
    def __init__(
84
        self,
0 ignored issues
show
Coding Style introduced by
Wrong hanging indentation before block (add 4 spaces).
Loading history...
85
        token: str, *,
0 ignored issues
show
Coding Style introduced by
Wrong hanging indentation before block (add 4 spaces).
Loading history...
86
        intents: Intents,
0 ignored issues
show
Coding Style introduced by
Wrong hanging indentation before block (add 4 spaces).
Loading history...
87
        url: str,
0 ignored issues
show
Coding Style introduced by
Wrong hanging indentation before block (add 4 spaces).
Loading history...
88
        shard: int,
0 ignored issues
show
Coding Style introduced by
Wrong hanging indentation before block (add 4 spaces).
Loading history...
89
        num_shards: int
0 ignored issues
show
Coding Style introduced by
Wrong hanging indentation before block (add 4 spaces).
Loading history...
90
    ) -> None:
91
        if len(token) != 59:
92
            raise InvalidTokenError(
93
                "Discord Token must have exactly 59 characters."
94
            )
95
96
        self.token = token
97
        self.intents = intents
98
        self.url = url
99
        self.shard = shard
100
        self.num_shards = num_shards
101
        self.shard_key = [shard, num_shards]
102
103
        self.__dispatch_handlers: Dict[int, Handler] = {
104
            1: self.handle_heartbeat_req,
105
            7: self.handle_reconnect,
106
            9: self.handle_invalid_session,
107
            10: self.identify_and_handle_hello,
108
            11: self.handle_heartbeat
109
        }
110
111
        # 4000 and 4009 are not included. The client will reconnect when receiving
112
        # either.
113
        self.__close_codes: Dict[int, GatewayError] = {
114
            4001: GatewayError("Invalid opcode was sent"),
115
            4002: GatewayError("Invalid payload was sent."),
116
            4003: GatewayError("Payload was sent prior to identifying"),
117
            4004: GatewayError("Token is not valid"),
118
            4005: GatewayError(
119
                "Authentication was sent after client already authenticated"
120
            ),
121
            4007: GatewayError("Invalid sequence sent when starting new session"),
122
            4008: GatewayError("Client was rate limited"),
123
            4010: GatewayError("Invalid shard"),
124
            4011: GatewayError("Sharding required"),
125
            4012: GatewayError("Invalid API version"),
126
            4013: GatewayError("Invalid intents"),
127
            4014: GatewayError("Disallowed intents")
128
        }
129
130
        # ClientSession to be used for this Dispatcher
131
        self.__session: Optional[ClientSession] = None
132
133
        # This type `_WSRequestContextManager` isn't exposed by aiohttp.
134
        # `ClientWebSocketResponse` is a parent class.
135
        self.__socket: Optional[ClientWebSocketResponse] = None
136
137
        # Buffer used to store information in transport conpression.
138
        self.__buffer = bytearray()
139
140
        # The gateway can be disconnected from Discord. This variable stores if the
141
        # gateway should send a hello or reconnect.
142
        self.__should_reconnect: bool = False
143
144
        # The sequence number for the last received payload. This is used reconnecting.
145
        self.__sequence_number: int = 0
146
147
        # The heartbeat task
148
        self.__heartbeat_task: Optional[Task] = None
149
150
        # Keeps the Client waiting until the next heartbeat
151
        self.__wait_for_heartbeat: Optional[Task] = None
152
153
        # How long the client should wait between each Heartbeat.
154
        self.__heartbeat_interval: Optional[int] = None
155
156
        # Tracks whether the gateway has received an ack (opcode 11) since the last
157
        # heartbeat.
158
        #   True: An ack has been received
159
        #   False: No ack has been received. Attempt to reconnect with gateway,
160
        self.__has_received_ack: bool = True
161
162
        # Session ID received from `on_ready` event. It is set in the `on_ready`
163
        # middleware. This is used reconnecting.
164
        self.__session_id: Optional[str] = None
165
166
    def __del__(self):
167
        """Delete method ensures all connections are closed"""
168
        if self.__socket:
169
            create_task(self.__socket.close())
170
        if self.__session:
171
            create_task(self.__session.close())
172
173
    async def init_session(self):
174
        """|coro|
175
        Crates the ClientSession. ALWAYS run this function right after initializing
176
        a Gateway.
177
        """
178
        self.__session = ClientSession()
179
180
    def append_handlers(self, handlers: Dict[int, Handler]):
181
        """The Client that uses the handler can append their own methods. The gateway
182
        will run those methods when the specified opcode is received.
183
        """
184
        self.__dispatch_handlers |= handlers
185
186
    def set_session_id(self, _id: str):
187
        """Session id is private for consistency"""
188
        self.__session_id = _id
189
190
    def decompress_msg(self, msg: bytes) -> Optional[str]:
0 ignored issues
show
introduced by
Missing function or method docstring
Loading history...
191
        if GatewayConfig.compression == "zlib-payload":
192
            return inflator.decompress(msg)
193
194
        if GatewayConfig.compression == "zlib-stream":
195
            self.__buffer.extend(msg)
196
197
            if len(self.__buffer) < 4 or self.__buffer[-4:] != ZLIB_SUFFIX:
198
                return None
199
200
            msg = inflator.decompress(msg)
201
            self.__buffer = bytearray()
202
            return msg
203
204
        return None
205
206
    async def start_loop(self):
207
        """|coro|
208
        Instantiate the dispatcher, this will create a connection to the
209
        Discord websocket API on behalf of the client whose token has
210
        been passed.
211
        """
212
        for _try in count():
213
            try:
214
                self.__socket = await self.__session.ws_connect(
215
                    GatewayConfig.make_uri(self.url)
216
                )
217
                break
218
            except ClientConnectorError as e:
219
                if _try > GatewayConfig.MAX_RETRIES:
220
                    raise GatewayConnectionError from e
221
222
                _log.warning(
223
                    "%s Could not open websocket with Discord."
224
                    " Retrying in 15 seconds...",
225
                    self.shard_key
226
                )
227
                await sleep(15)
228
229
        _log.debug("%s Starting envent loop...", self.shard_key)
230
        await self.event_loop()
231
232
    async def event_loop(self):
233
        """|coro|
234
        Handles receiving messages and decompressing them if needed
235
        """
236
        async for msg in self.__socket:
237
            if msg.type == WSMsgType.TEXT:
238
                await self.handle_data(msg.data)
239
            elif msg.type == WSMsgType.BINARY:
240
                # Message from transport compression that isn't complete returns None
241
                data = self.decompress_msg(msg.data)
242
                if data:
243
                    await self.handle_data(data)
244
            elif msg.type == WSMsgType.ERROR:
245
                raise GatewayError from self.__socket.exception()
246
247
        # The loop is broken when the gateway stops receiving messages.
248
        # The "error" op codes are in `self.__close_codes`. The rest of the
249
        # close codes are unknown issues (such as a unintended disconnect) so the
250
        # client should reconnect to the gateway.
251
        err = self.__close_codes.get(self.__socket.close_code)
252
253
        if err:
254
            raise err
255
256
        _log.debug(
257
            "%s Disconnected from Gateway due without any errors. Reconnecting.",
258
            self.shard_key
259
        )
260
        self.__should_reconnect = True
261
        self.start_loop()
262
263
    async def handle_data(self, data: Dict[Any]):
264
        """|coro|
265
        Method is run when a payload is received from the gateway.
266
        The message is expected to already have been decompressed.
267
        Handling the opcode is forked to the background so they aren't blocking.
268
        """
269
        payload = GatewayDispatch.from_string(data)
270
271
        # Op code -1 is activated on all payloads
272
        op_negative_one = self.__dispatch_handlers.get(-1)
273
        if op_negative_one:
274
            ensure_future(op_negative_one(payload))
275
276
        _log.debug(
277
            "%s %s GatewayDispatch with opcode %s received",
278
            self.shard_key,
279
            datetime.now(),
280
            payload.op
281
        )
282
283
        # Many events are sent with a `null` sequence. This sequence should not
284
        # be tracked.
285
        if payload.seq is not None:
286
            self.__sequence_number = payload.seq
287
            _log.debug("%s Set sequence number to %s", self.shard_key, payload.seq)
288
289
        handler = self.__dispatch_handlers.get(payload.op)
290
291
        if handler is None:
292
            raise UnhandledException(f"Opcode {payload.op} does not have a handler")
293
294
        ensure_future(handler(payload))
295
296
    async def handle_heartbeat_req(self, payload: GatewayDispatch):
0 ignored issues
show
Unused Code introduced by
The argument payload seems to be unused.
Loading history...
297
        """|coro|
298
        Opcode 1 - Instantly send a heartbeat.
299
        """
300
        self.send_next_heartbeat()
301
302
    async def handle_reconnect(self, payload: GatewayDispatch):
0 ignored issues
show
Unused Code introduced by
The argument payload seems to be unused.
Loading history...
303
        """|coro|
304
        Opcode 7 - Reconnect and resume immediately.
305
        """
306
        _log.debug(
307
            "%s Requested to reconnect to Discord. Closing session and attempting to"
308
            " resume...",
309
            self.shard_key
310
        )
311
312
        await self.__socket.close(code=1000)
313
        self.__should_reconnect = True
314
        await self.start_loop()
315
316
    async def handle_invalid_session(self, payload: GatewayDispatch):
0 ignored issues
show
Unused Code introduced by
The argument payload seems to be unused.
Loading history...
317
        """|coro|
318
        Opcode 9 - Invalid connection
319
        Attempt to relog. This is probably because the session was already invalidated
320
        when we tried to reconnect.
321
        """
322
        _log.debug("%s Invalid session, attempting to relog...", self.shard_key)
323
        self.__should_reconnect = False
324
        await self.start_loop()
325
326
    async def identify_and_handle_hello(self, payload: GatewayDispatch):
327
        """|coro|
328
        Opcode 10 - Hello there general kenobi
329
        Runs when we connect to the gateway for the first time and every time after.
330
        If the client thinks it should reconnect, the opcode 6 resume payload is sent
331
        instead of the opcode 2 hello payload. A new session is only started after a
332
        reconnect if pcode 9 is received.
333
334
        Successful reconnects are handled in the `resumed` middleware.
335
        """
336
        if self.__should_reconnect:
337
            _log.debug("%s Resuming connection with Discord", self.shard_key)
338
339
            await self.send(str(GatewayDispatch(
340
                6, {
341
                    "token": self.token,
342
                    "session_id": self.__session_id,
343
                    "seq": self.__sequence_number
344
                }
345
            )))
346
            return
347
348
        await self.send(str(
349
            GatewayDispatch(
350
                2, {
351
                    "token": self.token,
352
                    "intents": self.intents,
353
                    "properties": {
354
                        "$os": system(),
355
                        "$browser": __package__,
356
                        "$device": __package__
357
                    },
358
                    "compress": GatewayConfig.compressed(),
359
                    "shard": self.shard_key
360
                }
361
            )
362
        ))
363
        self.__heartbeat_interval = payload.data["heartbeat_interval"]
364
365
        # This process should already be forked to the background so there is no need to
366
        # `ensure_future()` here.
367
        self.start_heartbeat()
368
369
    async def handle_heartbeat(self, payload: GatewayDispatch):
0 ignored issues
show
Unused Code introduced by
The argument payload seems to be unused.
Loading history...
370
        """|coro|
371
        Opcode 11 - Heatbeat
372
        Track that the heartbeat has been received using shared state (Rustaceans would
373
        be very mad)
374
        """
375
        self.__has_received_ack = True
376
377
    async def send(self, payload: str):
378
        """|coro|
379
        Send a string object to the payload. Most of this method is just logging,
380
        the last line is the only one that matters for functionality.
381
        """
382
        safe_payload = payload.replace(self.token, "%s..." % self.token[:10])
383
384
        if self.__session_id:
385
            safe_payload = safe_payload.replace(
386
                self.__session_id, "%s..." % self.__session_id[:4]
387
            )
388
389
        _log.debug(
390
            "%s Sending payload: %s",
391
            self.shard_key,
392
            safe_payload
393
        )
394
395
        if self.__socket.closed:
396
            _log.debug(
397
                "%s Socket is closing. Payload not sent.",
398
                self.shard_key
399
            )
400
            return
401
402
        await self.__socket.send_str(payload)
403
404
    def start_heartbeat(self):
405
        """
406
        Starts the heartbeat if it is not already running.
407
        """
408
        if not self.__heartbeat_task or self.__heartbeat_task.cancelled():
409
            self.__heartbeat_task = ensure_future(self.__heartbeat_loop())
410
411
    def stop_heartbeat(self):
0 ignored issues
show
introduced by
Missing function or method docstring
Loading history...
412
        self.__heartbeat_task.cancel()
413
414
    def send_next_heartbeat(self):
415
        """
416
        It is expected to always be waiting for a hearbeat. By canceling that task,
417
        a heartbeat can be sent.
418
        """
419
        self.__wait_for_heartbeat.cancel()
420
421
    async def __heartbeat_loop(self):
422
        """|coro|
423
        The heartbeat is responsible for keeping the connection to Discord alive.
424
425
        Jitter is only random for the first heartbeat. It should be 1 every other
426
        heartbeat.
427
        """
428
        _log.debug("%s Starting heartbeat loop...", self.shard_key)
429
430
        # When waiting for first heartbeat, there hasn't been an ack received yet.
431
        # Set to true so the ack received check doesn't incorrectly fail.
432
        self.__has_received_ack = True
433
434
        for jitter in chain((random(),), repeat(1)):
435
            duration = self.__heartbeat_interval * jitter
436
437
            _log.debug(
438
                "%s %s sending heartbeat in %sms",
439
                self.shard_key, datetime.now(),
440
                duration
441
            )
442
443
            # Task is needed so waiting can be cancelled by op code 1
444
            self.__wait_for_heartbeat = create_task(
445
                sleep(duration / 1000)
446
            )
447
448
            await self.__wait_for_heartbeat
449
450
            if not self.__has_received_ack:
451
                # Close code is specified to be anything that is not 1000 in the docs.
452
                _log.debug(
453
                    "%s %s ack not received. Attempting to reconnect."
454
                    " Closing socket with close code 1001. \U0001f480",
455
                    datetime.now(),
456
                    self.shard_key
457
                )
458
                await self.__socket.close(code=1001)
459
                self.__should_reconnect = True
460
                # A new loop is started in the background while this one is stopped.
461
                ensure_future(self.start_loop())
462
                self.stop_heartbeat()
463
                return
464
465
            self.__has_received_ack = False
466
            await self.send(str(GatewayDispatch(1, data=self.__sequence_number)))
467
            _log.debug("%s sent heartbeat", self.shard_key)
468