Passed
Pull Request — main (#329)
by
unknown
02:04
created

pincer.core.gateway.Gateway.__del__()   A

Complexity

Conditions 3

Size

Total Lines 6
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 5
dl 0
loc 6
rs 10
c 0
b 0
f 0
cc 3
nop 1
1
# Copyright Pincer 2021-Present
0 ignored issues
show
introduced by
Missing module docstring
Loading history...
2
# Full MIT License can be found in `LICENSE` at the project root.
3
4
5
from __future__ import annotations

from asyncio import create_task, Task, ensure_future, sleep, wait
from dataclasses import dataclass
from datetime import datetime
from itertools import repeat, count, chain
import logging
from platform import system
from random import random
from typing import TYPE_CHECKING, Any, Dict, Callable, Optional
from zlib import decompressobj

from aiohttp import (
    ClientSession, WSMsgType, ClientConnectorError, ClientWebSocketResponse
)

from . import __package__
from ..utils.api_object import APIObject
from .._config import GatewayConfig
from ..core.dispatch import GatewayDispatch
from ..exceptions import (
    InvalidTokenError, GatewayConnectionError, GatewayError, UnhandledException
)
28
29
if TYPE_CHECKING:
30
    from ..objects.app.intents import Intents
31
    Handler = Callable[[GatewayDispatch], None]
32
33
# Logger for this module, named after the package.
_log = logging.getLogger(__package__)

# The four bytes 00 00 FF FF. A "zlib-stream" message is treated as complete
# only once the received buffer ends with this marker.
ZLIB_SUFFIX = bytes((0x00, 0x00, 0xFF, 0xFF))

# A single zlib decompression context created at import time; being
# module-level, it is shared by everything in this module that uses it.
inflator = decompressobj()
37
38
39
@dataclass
class SessionStartLimit(APIObject):
    """Session start limit info returned from the `gateway/bot` endpoint.

    Attributes
    ----------
    total:
        Total session starts allowed in the current window.
    remaining:
        Session starts still available in the current window.
    reset_after:
        Time until the limit resets (Discord documents this in
        milliseconds -- NOTE(review): confirm against the API docs).
    max_concurrency:
        Number of identify requests allowed concurrently.
    """
    total: int
    remaining: int
    reset_after: int
    max_concurrency: int
46
47
48
@dataclass
class GatewayInfo(APIObject):
    """Gateway info returned from the `gateway/bot` endpoint.

    Attributes
    ----------
    url:
        Websocket URL to connect the gateway to.
    shards:
        Shard count suggested by the endpoint.
    session_start_limit:
        Session start limits, as a :class:`SessionStartLimit`.
    """
    url: str
    shards: int
    session_start_limit: SessionStartLimit
54
55
56
class Gateway:
    """The Gateway handles all interactions with the Discord Websocket API.
    This also contains the main event loop, and handles the heartbeat.

    Running the Gateway will create a connection with the
    Discord Websocket API on behalf of the provided token.

    This token must be a bot token.
    (Which can be found on
    `<https://discord.com/developers/applications/>`_)

    Reconnecting is owned by :meth:`start_loop`: whenever :meth:`event_loop`
    returns after a non-fatal disconnect, a new websocket is opened. Handlers
    that want a reconnect only close the current socket and record whether
    the next hello should resume or identify.
    """

    def __init__(
            self,
            token: str, *,
            intents: Intents,
            url: str,
            shard: int,
            num_shards: int
    ) -> None:
        if len(token) != 59:
            raise InvalidTokenError(
                "Discord Token must have exactly 59 characters."
            )

        self.token = token
        self.intents = intents
        self.url = url
        self.shard = shard
        self.num_shards = num_shards
        self.shard_key = [shard, num_shards]

        self.__dispatch_handlers: Dict[int, Handler] = {
            1: self.handle_heartbeat_req,
            7: self.handle_reconnect,
            9: self.handle_invalid_session,
            10: self.identify_and_handle_hello,
            11: self.handle_heartbeat
        }

        # 4000 and 4009 are not included. The client will reconnect when receiving
        # either.
        self.__close_codes: Dict[int, GatewayError] = {
            4001: GatewayError("Invalid opcode was sent"),
            4002: GatewayError("Invalid payload was sent."),
            4003: GatewayError("Payload was sent prior to identifying"),
            4004: GatewayError("Token is not valid"),
            4005: GatewayError(
                "Authentication was sent after client already authenticated"
            ),
            4007: GatewayError("Invalid sequence sent when starting new session"),
            4008: GatewayError("Client was rate limited"),
            4010: GatewayError("Invalid shard"),
            4011: GatewayError("Sharding required"),
            4012: GatewayError("Invalid API version"),
            4013: GatewayError("Invalid intents"),
            4014: GatewayError("Disallowed intents")
        }

        # ClientSession to be used for this Dispatcher
        self.__session: Optional[ClientSession] = None

        # This type `_WSRequestContextManager` isn't exposed by aiohttp.
        # `ClientWebSocketResponse` is a parent class.
        self.__socket: Optional[ClientWebSocketResponse] = None

        # Buffer that collects "zlib-stream" frames until a complete message
        # (ending in ZLIB_SUFFIX) has arrived.
        self.__buffer = bytearray()

        # "zlib-stream" decompression context. It belongs to this gateway so
        # shards don't share stream state, and it is recreated for every new
        # connection (a new connection presumably starts a new zlib stream --
        # see Discord's transport compression docs).
        self.__inflator = decompressobj()

        # Chooses the answer to the next hello:
        #   True: resume the previous session (opcode 6)
        #   False: identify from scratch (opcode 2)
        self.__should_reconnect: bool = False

        # Set right before this client closes its own socket, so `event_loop`
        # knows the disconnect was intentional and keeps the resume/identify
        # choice made by whoever closed it. Cleared on every new connection.
        self.__closed_by_client: bool = False

        # The sequence number for the last received payload. This is used reconnecting.
        self.__sequence_number: int = 0

        # The heartbeat task
        self.__heartbeat_task: Optional[Task] = None

        # Keeps the Client waiting until the next heartbeat
        self.__wait_for_heartbeat: Optional[Task] = None

        # How long the client should wait between each Heartbeat.
        self.__heartbeat_interval: Optional[int] = None

        # Tracks whether the gateway has received an ack (opcode 11) since the last
        # heartbeat.
        #   True: An ack has been received
        #   False: No ack has been received. Attempt to reconnect with gateway,
        self.__has_received_ack: bool = True

        # Session ID received from `on_ready` event. It is set in the `on_ready`
        # middleware. This is used reconnecting.
        self.__session_id: Optional[str] = None

    def __del__(self):
        """Delete method ensures all connections are closed.

        Closing is scheduled on the running event loop. When there is no
        running loop, nothing can be scheduled, so the close is skipped
        instead of raising inside ``__del__``.
        """
        # `getattr` because `__init__` may have raised (e.g. an invalid token)
        # before these attributes were assigned.
        resources = (
            getattr(self, "_Gateway__socket", None),
            getattr(self, "_Gateway__session", None)
        )

        for resource in resources:
            if resource is None or resource.closed:
                continue

            closing = resource.close()
            try:
                create_task(closing)
            except RuntimeError:
                # No running event loop: discard the coroutine so Python does
                # not warn that it was never awaited.
                closing.close()

    async def init_session(self):
        """
        Creates the ClientSession. ALWAYS run this function right after initializing
        a Gateway.
        """
        self.__session = ClientSession()

    def append_handlers(self, handlers: Dict[int, Handler]):
        """The Client that uses the handler can append their own methods. The gateway
        will run those methods when the specified opcode is received.

        Built-in handlers take precedence over ``handlers`` for the same opcode.
        """
        self.__dispatch_handlers = handlers | self.__dispatch_handlers

    def set_session_id(self, _id: str):
        """Session id is private for consistency"""
        self.__session_id = _id

    def decompress_msg(self, msg: bytes) -> Optional[str]:
        """Decompress a binary websocket message according to
        ``GatewayConfig.compression``.

        :param msg: Raw binary frame received from the websocket.
        :return: The decoded UTF-8 payload, or ``None`` when the frame is an
            incomplete "zlib-stream" chunk or no known compression mode is
            configured.
        """
        if GatewayConfig.compression == "zlib-payload":
            # Each payload is its own complete zlib stream, so it needs a fresh
            # decompressor: once a decompressobj reaches the end of a stream it
            # no longer decompresses further input.
            return decompressobj().decompress(msg).decode("utf-8")

        if GatewayConfig.compression == "zlib-stream":
            self.__buffer.extend(msg)

            # A message is complete only once the buffer ends with the suffix.
            if len(self.__buffer) < 4 or self.__buffer[-4:] != ZLIB_SUFFIX:
                return None

            # Decompress the whole buffered message (not just the last frame)
            # using this connection's stream context.
            data = self.__inflator.decompress(self.__buffer)
            self.__buffer = bytearray()
            return data.decode("utf-8")

        return None

    async def start_loop(self):
        """Instantiate the dispatcher, this will create a connection to the
        Discord websocket API on behalf of the client whose token has
        been passed.

        Keeps reconnecting after every non-fatal disconnect. It only returns by
        raising (a fatal close code, a websocket error, or running out of
        connection retries).

        :raises GatewayConnectionError: if the websocket cannot be opened.
        :raises GatewayError: on a fatal close code or websocket error.
        """
        if self.__session is None:
            await self.init_session()

        # Loop instead of recursing so repeated reconnects don't grow an
        # ever-deeper chain of awaiting coroutines.
        while True:
            await self.__connect()
            _log.debug("%s Starting envent loop...", self.shard_key)
            await self.event_loop()

    async def __connect(self):
        """Open a new websocket to the gateway, retrying on connection errors,
        and reset all per-connection state."""
        for _try in count():
            try:
                self.__socket = await self.__session.ws_connect(
                    GatewayConfig.make_uri(self.url)
                )
                break
            except ClientConnectorError as e:
                if _try > GatewayConfig.MAX_RETRIES:
                    raise GatewayConnectionError from e

                _log.warning(
                    "%s Could not open websocket with Discord."
                    " Retrying in 15 seconds...",
                    self.shard_key
                )
                await sleep(15)

        # Fresh connection: drop partial data and stream state from the old one.
        self.__buffer = bytearray()
        self.__inflator = decompressobj()
        self.__closed_by_client = False

    async def event_loop(self):
        """Handles receiving messages and decompressing them if needed.

        Returns once the current socket has closed for a reason that allows
        reconnecting; :meth:`start_loop` then opens a new connection.

        :raises GatewayError: on a websocket error or a fatal close code.
        """
        async for msg in self.__socket:
            if msg.type == WSMsgType.TEXT:
                await self.handle_data(msg.data)
            elif msg.type == WSMsgType.BINARY:
                # Message from transport compression that isn't complete returns None
                data = self.decompress_msg(msg.data)
                if data:
                    await self.handle_data(data)
            elif msg.type == WSMsgType.ERROR:
                raise GatewayError from self.__socket.exception()

        # The loop is broken when the gateway stops receiving messages.
        # The "error" op codes are in `self.__close_codes`. The rest of the
        # close codes are unknown issues (such as a unintended disconnect) so the
        # client should reconnect to the gateway.
        err = self.__close_codes.get(self.__socket.close_code)

        if err:
            raise err

        if self.__closed_by_client:
            # Closed on purpose; the closer already chose resume vs identify.
            return

        _log.debug(
            "%s Disconnected from Gateway due without any errors. Reconnecting.",
            self.shard_key
        )
        self.__should_reconnect = True

    async def __close_for_reconnect(self, *, resume: bool, code: int):
        """Close the current socket on purpose so :meth:`start_loop` reconnects.

        :param resume: ``True`` to answer the next hello with a resume
            (opcode 6), ``False`` to identify from scratch (opcode 2).
        :param code: websocket close code sent when closing.
        """
        # Set the flags before closing: the event loop may observe the
        # disconnect while `close()` is still running.
        self.__should_reconnect = resume
        self.__closed_by_client = True

        if self.__socket is not None:
            await self.__socket.close(code=code)

    async def handle_data(self, data: str):
        """Method is run when a payload is received from the gateway.
        The message is expected to already have been decompressed.
        Handling the opcode is forked to the background so they aren't blocking.

        :raises UnhandledException: if no handler exists for the opcode.
        """
        payload = GatewayDispatch.from_string(data)

        # Op code -1 is activated on all payloads
        op_negative_one = self.__dispatch_handlers.get(-1)
        if op_negative_one:
            ensure_future(op_negative_one(payload))

        _log.debug(
            "%s %s GatewayDispatch with opcode %s received",
            self.shard_key,
            datetime.now(),
            payload.op
        )

        # Many events are sent with a `null` sequence. This sequence should not
        # be tracked.
        if payload.seq is not None:
            self.__sequence_number = payload.seq
            _log.debug("%s Set sequence number to %s", self.shard_key, payload.seq)

        handler = self.__dispatch_handlers.get(payload.op)

        if handler is None:
            raise UnhandledException(f"Opcode {payload.op} does not have a handler")

        ensure_future(handler(payload))

    async def handle_heartbeat_req(self, payload: GatewayDispatch):
        """
        Opcode 1 - Instantly send a heartbeat.
        """
        self.send_next_heartbeat()

    async def handle_reconnect(self, payload: GatewayDispatch):
        """
        Opcode 7 - Reconnect and resume immediately.

        Closes the socket; :meth:`start_loop` opens the new connection and the
        next hello sends a resume.
        """
        _log.debug(
            "%s Requested to reconnect to Discord. Closing session and attempting to"
            " resume...",
            self.shard_key
        )

        # NOTE(review): Discord's docs say closing with 1000/1001 invalidates
        # the session, which would make this resume fail. Confirm the code.
        await self.__close_for_reconnect(resume=True, code=1000)

    async def handle_invalid_session(self, payload: GatewayDispatch):
        """
        Opcode 9 - Invalid connection
        Attempt to relog. This is probably because the session was already invalidated
        when we tried to reconnect.

        Closes the socket; :meth:`start_loop` opens the new connection and the
        next hello sends a fresh identify.
        """
        _log.debug("%s Invalid session, attempting to relog...", self.shard_key)
        await self.__close_for_reconnect(resume=False, code=1000)

    async def identify_and_handle_hello(self, payload: GatewayDispatch):
        """
        Opcode 10 - Hello there general kenobi
        Runs when we connect to the gateway for the first time and every time after.
        If the client thinks it should reconnect, the opcode 6 resume payload is sent
        instead of the opcode 2 identify payload. A new session is only started after a
        reconnect if opcode 9 is received.

        Either way, the heartbeat is (re)started with the interval from this hello.

        Successful reconnects are handled in the `resumed` middleware.
        """
        # Every hello carries the heartbeat interval for this connection.
        self.__heartbeat_interval = payload.data["heartbeat_interval"]

        if self.__should_reconnect:
            _log.debug("%s Resuming connection with Discord", self.shard_key)

            await self.send(str(GatewayDispatch(
                6, {
                    "token": self.token,
                    "session_id": self.__session_id,
                    "seq": self.__sequence_number
                }
            )))
        else:
            await self.send(str(
                GatewayDispatch(
                    2, {
                        "token": self.token,
                        "intents": self.intents,
                        "properties": {
                            "$os": system(),
                            "$browser": __package__,
                            "$device": __package__
                        },
                        "compress": GatewayConfig.compressed(),
                        "shard": self.shard_key
                    }
                )
            ))

        # Restart rather than "start if missing" so that:
        # - the new interval and initial jitter apply to this connection;
        # - the ack state is reset;
        # - a loop that stopped (e.g. after an unacked heartbeat) runs again,
        #   after a resume as well as after an identify.
        self.stop_heartbeat()
        self.__heartbeat_task = ensure_future(self.__heartbeat_loop())

    async def handle_heartbeat(self, payload: GatewayDispatch):
        """
        Opcode 11 - Heartbeat ack
        Track that the heartbeat has been received using shared state (Rustaceans would
        be very mad)
        """
        self.__has_received_ack = True

    async def send(self, payload: str):
        """
        Send a string object to the payload. Most of this method is just logging,
        the last line is the only one that matters for functionality.

        Nothing is sent when there is no socket yet or it is already closed.
        """
        safe_payload = payload.replace(self.token, "%s..." % self.token[:10])

        if self.__session_id:
            safe_payload = safe_payload.replace(
                self.__session_id, "%s..." % self.__session_id[:4]
            )

        _log.debug(
            "%s Sending payload: %s",
            self.shard_key,
            safe_payload
        )

        if self.__socket is None or self.__socket.closed:
            _log.debug(
                "%s Socket is closing. Payload not sent.",
                self.shard_key
            )
            return

        await self.__socket.send_str(payload)

    def start_heartbeat(self):
        """
        Starts the heartbeat if it is not already running.

        A loop that has finished for any reason (cancelled, returned, or
        raised) counts as not running.
        """
        if self.__heartbeat_task is None or self.__heartbeat_task.done():
            self.__heartbeat_task = ensure_future(self.__heartbeat_loop())

    def stop_heartbeat(self):
        """Cancel the heartbeat loop and its pending wait, if they are running."""
        for task in (self.__heartbeat_task, self.__wait_for_heartbeat):
            if task is not None and not task.done():
                task.cancel()

    def send_next_heartbeat(self):
        """
        It is expected to always be waiting for a heartbeat. By canceling that task,
        a heartbeat can be sent.
        """
        if self.__wait_for_heartbeat is not None:
            self.__wait_for_heartbeat.cancel()

    async def __heartbeat_loop(self):
        """
        The heartbeat is responsible for keeping the connection to Discord alive.

        Jitter is only random for the first heartbeat. It should be 1 every other
        heartbeat.

        If no ack (opcode 11) arrived since the previous heartbeat, the socket is
        closed for a resume and this loop ends. The next hello starts a new loop.
        """
        _log.debug("%s Starting heartbeat loop...", self.shard_key)

        # When waiting for first heartbeat, there hasn't been an ack received yet.
        # Set to true so the ack received check doesn't incorrectly fail.
        self.__has_received_ack = True

        for jitter in chain((random(),), repeat(1)):
            duration = self.__heartbeat_interval * jitter

            _log.debug(
                "%s %s sending heartbeat in %sms",
                self.shard_key, datetime.now(),
                duration
            )

            # Task is needed so waiting can be cancelled by op code 1
            self.__wait_for_heartbeat = create_task(
                sleep(duration / 1000)
            )

            # `wait` returns normally when the sleep task is cancelled by
            # `send_next_heartbeat`. Awaiting the task directly would raise
            # CancelledError here and end the loop without sending a heartbeat.
            await wait({self.__wait_for_heartbeat})

            if not self.__has_received_ack:
                # Close code is specified to be anything that is not 1000 in the docs.
                # NOTE(review): newer Discord docs also exclude 1001 for resuming.
                _log.debug(
                    "%s %s ack not received. Attempting to reconnect."
                    " Closing socket with close code 1001. \U0001f480",
                    self.shard_key,
                    datetime.now()
                )
                # `start_loop` opens the new connection once the event loop
                # sees this close.
                await self.__close_for_reconnect(resume=True, code=1001)
                return

            self.__has_received_ack = False
            await self.send(str(GatewayDispatch(1, data=self.__sequence_number)))
            _log.debug("%s sent heartbeat", self.shard_key)
449