Passed
Pull Request — main (#329)
by
unknown
01:46
created

pincer.core.gateway.Gateway.event_loop()   B

Complexity

Conditions 7

Size

Total Lines 28
Code Lines 18

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 18
dl 0
loc 28
rs 8
c 0
b 0
f 0
cc 7
nop 1
1
# Copyright Pincer 2021-Present
0 ignored issues
show
introduced by
Missing module docstring
Loading history...
2
# Full MIT License can be found in `LICENSE` at the project root.
3
4
5
from __future__ import annotations
6
7
from asyncio import create_task, Task, ensure_future, sleep
8
from dataclasses import dataclass
9
from datetime import datetime
10
from itertools import repeat, count, chain
11
import logging
12
from platform import system
13
from random import random
14
from typing import TYPE_CHECKING, Any, Dict, Callable, Optional
15
from zlib import decompressobj, decompress
0 ignored issues
show
Unused Code introduced by
Unused decompress imported from zlib
Loading history...
16
17
from aiohttp import (
18
    ClientSession, WSMsgType, ClientConnectorError, ClientWebSocketResponse
19
)
20
21
from . import __package__
22
from ..utils.api_object import APIObject
23
from .._config import GatewayConfig
24
from ..core.dispatch import GatewayDispatch
25
from ..exceptions import (
0 ignored issues
show
Bug Best Practice introduced by
This seems to re-define the built-in ConnectionError.

It is generally discouraged to redefine built-ins as this makes code very hard to read.

Loading history...
26
    InvalidTokenError, ConnectionError, GatewayError, UnhandledException
27
)
28
29
if TYPE_CHECKING:
30
    from ..objects.app.intents import Intents
31
    Handler = Callable[[GatewayDispatch], None]
32
33
_log = logging.getLogger(__package__)

# A zlib-stream message is treated as complete once the receive buffer ends
# with this 4-byte marker (checked in `Gateway.decompress_msg`).
ZLIB_SUFFIX = bytes.fromhex("0000ffff")

# Decompression context used by `Gateway.decompress_msg`.
# NOTE(review): this object lives at module level, so every Gateway instance
# (every shard) in the process shares it, and nothing here resets it between
# connections — confirm only one zlib stream is ever active at a time.
inflator = decompressobj()
37
38
39
@dataclass
class GatewayInfo(APIObject):
    """Gateway connection information.

    Holds the gateway ``url``, a ``shards`` count and the
    ``session_start_limit`` (a `SessionStartLimit`).

    NOTE(review): presumably built from Discord's "Get Gateway Bot"
    response — confirm against the code that constructs it.
    """
    url: str
    shards: int
    session_start_limit: SessionStartLimit
44
45
46
@dataclass
class SessionStartLimit(APIObject):
    """Session-start limits attached to `GatewayInfo`.

    Fields: ``total``, ``remaining``, ``reset_after`` and
    ``max_concurrency``, all integers.

    NOTE(review): the unit of ``reset_after`` is not visible here
    (presumably milliseconds) — confirm before relying on it.
    """
    total: int
    remaining: int
    reset_after: int
    max_concurrency: int
52
53
54
class Gateway:
    """The Gateway handles all interactions with the Discord Websocket API.
    This also contains the main event loop, and handles the heartbeat.

    Running the Gateway will create a connection with the
    Discord Websocket API on behalf of the provided token.

    This token must be a bot token.
    (Which can be found on
    `<https://discord.com/developers/applications/>`_)
    """

    def __init__(
            self,
            token: str, *,
            intents: Intents,
            url: str,
            shard: int,
            num_shards: int
    ) -> None:
        """Store the connection settings and prepare the gateway state.

        Parameters
        ----------
        token:
            Bot token, sent in the identify and resume payloads.
        intents:
            Intents sent in the identify payload.
        url:
            Base gateway URL, passed to ``GatewayConfig.make_uri``.
        shard, num_shards:
            This connection's shard id and the total shard count; sent as
            ``[shard, num_shards]`` when identifying.

        Raises
        ------
        InvalidTokenError
            If ``token`` is not exactly 59 characters long.
        """
        # NOTE(review): confirm every valid bot token really is exactly 59
        # characters; otherwise this check rejects valid tokens.
        if len(token) != 59:
            raise InvalidTokenError(
                "Discord Token must have exactly 59 characters."
            )

        self.token = token
        self.intents = intents
        self.url = url
        self.shard = shard
        self.num_shards = num_shards
        # Used as the log-line prefix and as the identify payload's `shard`.
        self.shard_key = [shard, num_shards]

        # Opcode -> coroutine handler. Clients may add handlers through
        # `append_handlers`, but cannot replace these built-in ones.
        self.__dispatch_handlers: Dict[int, Handler] = {
            1: self.handle_heartbeat_req,
            7: self.handle_reconnect,
            9: self.handle_invalid_session,
            10: self.identify_and_handle_hello,
            11: self.handle_heartbeat
        }

        # 4000 and 4009 are not included. The client will reconnect when receiving
        # either.
        self.__close_codes: Dict[int, GatewayError] = {
            4001: GatewayError("Invalid opcode was sent"),
            4002: GatewayError("Invalid payload was sent."),
            4003: GatewayError("Payload was sent prior to identifying"),
            4004: GatewayError("Token is not valid"),
            4005: GatewayError(
                "Authentication was sent after client already authenticated"
            ),
            4007: GatewayError("Invalid sequence sent when starting new session"),
            4008: GatewayError("Client was rate limited"),
            4010: GatewayError("Invalid shard"),
            4011: GatewayError("Sharding required"),
            4012: GatewayError("Invalid API version"),
            4013: GatewayError("Invalid intents"),
            4014: GatewayError("Disallowed intents")
        }

        # ClientSession used for this gateway; created by `init_session`.
        self.__session: Optional[ClientSession] = None

        # This type `_WSRequestContextManager` isn't exposed by aiohttp.
        # `ClientWebSocketResponse` is a parent class.
        self.__socket: Optional[ClientWebSocketResponse] = None

        # Accumulates zlib-stream chunks until a complete message (ending in
        # ZLIB_SUFFIX) has arrived.
        self.__buffer = bytearray()

        # The gateway can be disconnected from Discord. This flag stores whether
        # the next hello should be answered with a resume (True) or a fresh
        # identify (False).
        self.__should_reconnect: bool = False

        # Set right before this gateway closes its own socket in order to
        # reconnect. `event_loop` checks it so it does not open a second
        # connection; `start_loop` clears it once a new socket is open.
        self.__closing_for_reconnect: bool = False

        # Sequence number of the last received payload; sent when resuming.
        self.__sequence_number: int = 0

        # The heartbeat task.
        self.__heartbeat_task: Optional[Task] = None

        # The sleep the heartbeat loop is currently waiting on; cancelling it
        # makes the loop send a heartbeat early.
        self.__wait_for_heartbeat: Optional[Task] = None

        # How long the client should wait between heartbeats, in milliseconds.
        self.__heartbeat_interval: Optional[int] = None

        # Tracks whether the gateway has received an ack (opcode 11) since the
        # last heartbeat.
        #   True: An ack has been received
        #   False: No ack has been received. Attempt to reconnect with gateway.
        self.__has_received_ack: bool = True

        # Session ID received from the `on_ready` event. It is set in the
        # `on_ready` middleware and used when resuming.
        self.__session_id: Optional[str] = None

    def __del__(self):
        """Best-effort cleanup: schedule closing of the socket and session.

        Closing is asynchronous, so it can only be scheduled while an event
        loop is running. ``getattr`` is used because ``__init__`` may have
        raised (e.g. on an invalid token) before these attributes were set.
        """
        # Mangled names, since string literals are not name-mangled.
        for attribute in ("_Gateway__socket", "_Gateway__session"):
            resource = getattr(self, attribute, None)
            if not resource:
                continue
            closer = resource.close()
            try:
                create_task(closer)
            except RuntimeError:
                # No running event loop (e.g. interpreter shutdown). Close the
                # coroutine so Python does not warn that it was never awaited.
                closer.close()

    async def init_session(self):
        """
        Creates the ClientSession. ALWAYS run this function right after
        initializing a Gateway.
        """
        self.__session = ClientSession()

    def append_handlers(self, handlers: Dict[int, Handler]):
        """The Client that uses the handler can append their own methods. The gateway
        will run those methods when the specified opcode is received.

        Built-in handlers take precedence: an opcode that already has a
        built-in handler keeps it.
        """
        self.__dispatch_handlers = handlers | self.__dispatch_handlers

    def set_session_id(self, _id: str):
        """Session id is private for consistency"""
        self.__session_id = _id

    def decompress_msg(self, msg: bytes) -> Optional[bytes]:
        """Decompress a binary websocket message per ``GatewayConfig.compression``.

        * ``"zlib-payload"``: every message is decompressed on its own.
        * ``"zlib-stream"``: chunks are buffered until the buffer ends with
          ``ZLIB_SUFFIX``; the whole buffer is then fed to the stream
          decompressor ``inflator``.

        Returns the decompressed bytes, or ``None`` while a zlib-stream message
        is still incomplete or when neither compression mode is configured.
        """
        if GatewayConfig.compression == "zlib-payload":
            # Each payload is compressed independently, so it must not go
            # through the stream context: once a zlib stream has ended, a
            # `decompressobj` puts further input into `unused_data` instead of
            # decompressing it.
            return decompress(msg)

        if GatewayConfig.compression == "zlib-stream":
            self.__buffer.extend(msg)

            if len(self.__buffer) < 4 or self.__buffer[-4:] != ZLIB_SUFFIX:
                return None

            # Decompress everything buffered, not only the final chunk;
            # otherwise the earlier chunks of a multi-frame message are lost.
            data = inflator.decompress(self.__buffer)
            self.__buffer = bytearray()
            return data

        return None

    async def start_loop(self):
        """Connect to the Discord websocket API on behalf of the client whose
        token was passed, then run `event_loop` on the new socket.

        A failed connection attempt (``ClientConnectorError``) is retried
        every 15 seconds. ``ConnectionError`` is raised once the attempt index
        exceeds ``GatewayConfig.MAX_RETRIES``, i.e. after
        ``MAX_RETRIES + 2`` failed attempts.
        NOTE(review): confirm whether ``MAX_RETRIES + 1`` attempts was intended.
        """
        for i in count():
            try:
                self.__socket = await self.__session.ws_connect(
                    GatewayConfig.make_uri(self.url)
                )
                break
            except ClientConnectorError as e:
                if i > GatewayConfig.MAX_RETRIES:
                    raise ConnectionError from e

                _log.warning(
                    "%s Could not open websocket with Discord. "
                    "Retrying in 15 seconds...",
                    self.shard_key
                )
                await sleep(15)

        # The new socket is open, so any self-initiated close is finished.
        self.__closing_for_reconnect = False

        _log.debug("%s Starting envent loop...", self.shard_key)
        await self.event_loop()

    async def event_loop(self):
        """Receive messages until the socket closes, then decide what to do.

        TEXT frames are handled directly. BINARY frames go through
        `decompress_msg` first; an incomplete zlib-stream message yields
        ``None`` and is skipped. An ERROR frame raises `GatewayError`.

        When the socket closes:

        * if this gateway closed it itself to reconnect, the code that closed
          it opens the next connection, so nothing more happens here;
        * a close code listed in ``__close_codes`` raises that `GatewayError`;
        * anything else is treated as an unintended disconnect, and the gateway
          reconnects and tries to resume.
        """
        # Keep a local reference: another handler may replace `self.__socket`.
        ws = self.__socket
        async for msg in ws:
            if msg.type == WSMsgType.TEXT:
                await self.handle_data(msg.data)
            elif msg.type == WSMsgType.BINARY:
                # Message from transport compression that isn't complete returns None
                data = self.decompress_msg(msg.data)
                if data:
                    await self.handle_data(data)
            elif msg.type == WSMsgType.ERROR:
                raise GatewayError from ws.exception()

        if self.__closing_for_reconnect:
            _log.debug(
                "%s Socket was closed by the gateway itself; the next "
                "connection is opened by the code that closed it.",
                self.shard_key
            )
            return

        # The "error" op codes are in `self.__close_codes`. The rest of the
        # close codes are unknown issues (such as an unintended disconnect) so
        # the client should reconnect to the gateway.
        err = self.__close_codes.get(ws.close_code)

        if err:
            raise err

        _log.debug(
            "%s Disconnected from Gateway due without any errors. Reconnecting.",
            self.shard_key
        )
        self.__should_reconnect = True
        # Must be awaited: a bare `self.start_loop()` only creates a coroutine
        # object and never actually reconnects.
        await self.start_loop()

    async def handle_data(self, data: Any):
        """Parse one already-decompressed payload and dispatch it.

        ``data`` is either the text of a TEXT frame or the bytes returned by
        `decompress_msg`. The handler registered for opcode -1 (if any) runs
        for every payload, then the handler for the payload's own opcode.
        Handlers are scheduled with `ensure_future`, so they do not block the
        receive loop.

        Raises
        ------
        UnhandledException
            If no handler is registered for the payload's opcode.
        """
        payload = GatewayDispatch.from_string(data)

        # Op code -1 is activated on all payloads
        op_negative_one = self.__dispatch_handlers.get(-1)
        if op_negative_one:
            ensure_future(op_negative_one(payload))

        _log.debug(
            "%s %s GatewayDispatch with opcode %s received",
            self.shard_key,
            datetime.now(),
            payload.op
        )

        # Many events are sent with a `null` sequence. This sequence should not
        # be tracked.
        if payload.seq is not None:
            self.__sequence_number = payload.seq
            _log.debug("%s Set sequence number to %s", self.shard_key, payload.seq)

        handler = self.__dispatch_handlers.get(payload.op)

        if handler is None:
            raise UnhandledException(f"Opcode {payload.op} does not have a handler")

        # NOTE(review): the returned task is not stored; the asyncio docs
        # recommend keeping a reference so a task cannot be garbage-collected
        # mid-execution — consider tracking these tasks.
        ensure_future(handler(payload))

    async def handle_heartbeat_req(self, payload: GatewayDispatch):
        """
        Opcode 1 - Instantly send a heartbeat. ``payload`` is not used.
        """
        self.send_next_heartbeat()

    async def handle_reconnect(self, payload: GatewayDispatch):
        """
        Opcode 7 - Reconnect and resume immediately. ``payload`` is not used.

        The current socket is closed as a self-initiated close (so
        `event_loop` does not reconnect a second time), and a new connection
        that resumes the session is started.
        """
        _log.debug(
            "%s Requested to reconnect to Discord. Closing session and attempting to"
            " resume...",
            self.shard_key
        )

        # NOTE(review): Discord's gateway docs indicate that closing with
        # 1000/1001 may invalidate the session, which would make the resume
        # below fail — confirm the close code.
        await self.__close_for_reconnect(1000)
        self.__should_reconnect = True
        await self.start_loop()

    async def handle_invalid_session(self, payload: GatewayDispatch):
        """
        Opcode 9 - Invalid connection. ``payload`` is not used.

        Attempt to relog. This is probably because the session was already
        invalidated when we tried to reconnect. The current socket is closed
        first, so the old connection is not left open and still being read
        while the new one starts; the new connection then sends a fresh
        identify.
        """
        _log.debug("%s Invalid session, attempting to relog...", self.shard_key)
        await self.__close_for_reconnect(1000)
        self.__should_reconnect = False
        await self.start_loop()

    async def identify_and_handle_hello(self, payload: GatewayDispatch):
        """
        Opcode 10 - Hello.

        Runs when we connect to the gateway for the first time and every time
        after. If the client thinks it should reconnect, the opcode 6 resume
        payload is sent instead of the opcode 2 identify payload. A new session
        is only started after a reconnect if opcode 9 is received.

        In both cases the heartbeat interval from this hello is stored and the
        heartbeat loop is (re)started. After a missed ack the old heartbeat
        loop has stopped itself, so a resumed connection would otherwise never
        send heartbeats again.

        Successful reconnects are handled in the `resumed` middleware.
        """
        if self.__should_reconnect:
            _log.debug("%s Resuming connection with Discord", self.shard_key)

            await self.send(str(GatewayDispatch(
                6,
                {
                    "token": self.token,
                    "session_id": self.__session_id,
                    "seq": self.__sequence_number
                }
            )))
        else:
            await self.send(str(
                GatewayDispatch(
                    2, {
                        "token": self.token,
                        "intents": self.intents,
                        "properties": {
                            "$os": system(),
                            "$browser": __package__,
                            "$device": __package__
                        },
                        "compress": GatewayConfig.compressed(),
                        "shard": self.shard_key
                    }
                )
            ))

        self.__heartbeat_interval = payload.data["heartbeat_interval"]

        # Handlers already run in the background (see `handle_data`), and
        # `start_heartbeat` schedules its own task, so there is no need to
        # `ensure_future()` here.
        self.start_heartbeat()

    async def handle_heartbeat(self, payload: GatewayDispatch):
        """
        Opcode 11 - Heartbeat ACK. ``payload`` is not used.

        Records that the ack for the last heartbeat has arrived; the heartbeat
        loop checks this flag before sending the next heartbeat.
        """
        self.__has_received_ack = True

    async def send(self, payload: str):
        """
        Send a string payload over the socket. Most of this method is just
        logging (with the token and session id shortened); the payload is not
        sent if the socket is already closed.
        """
        safe_payload = payload.replace(self.token, "%s..." % self.token[:10])

        if self.__session_id:
            safe_payload = safe_payload.replace(
                self.__session_id, "%s..." % self.__session_id[:4]
            )

        _log.debug(
            "%s Sending payload: %s",
            self.shard_key,
            safe_payload
        )

        if self.__socket.closed:
            _log.debug(
                "%s Socket is closing. Payload not sent.",
                self.shard_key
            )
            return

        await self.__socket.send_str(payload)

    async def __close_for_reconnect(self, code: int) -> None:
        """Close the current socket on purpose; the caller opens the next one.

        The flag is set before closing so that `event_loop`, which returns as
        soon as the socket closes, knows not to reconnect as well.
        """
        self.__closing_for_reconnect = True
        await self.__socket.close(code=code)

    def start_heartbeat(self):
        """
        Start the heartbeat loop unless one is already running.

        A task that has finished for any reason (cancelled via
        `stop_heartbeat` or ended by an exception) is replaced, so the
        heartbeat cannot silently stay dead after a failure.
        """
        if not self.__heartbeat_task or self.__heartbeat_task.done():
            self.__heartbeat_task = ensure_future(self.__heartbeat_loop())

    def stop_heartbeat(self):
        """Cancel the heartbeat loop, if one has been started."""
        if self.__heartbeat_task is not None:
            self.__heartbeat_task.cancel()

    def send_next_heartbeat(self):
        """
        Cut the current heartbeat wait short so a heartbeat is sent now.

        The heartbeat loop waits on ``__wait_for_heartbeat``; cancelling that
        task wakes the loop early. Does nothing if the loop has not started
        waiting yet.
        """
        if self.__wait_for_heartbeat is not None:
            self.__wait_for_heartbeat.cancel()

    async def __heartbeat_loop(self):
        """
        The heartbeat is responsible for keeping the connection to Discord alive.

        Jitter is only random for the first heartbeat; every later heartbeat
        waits the full interval. If no ack arrived since the previous
        heartbeat, the socket is closed and a new connection (resume) is
        started in the background, and this loop stops itself.
        """
        # Local import: only this loop needs them.
        from asyncio import CancelledError, wait  # pylint: disable=import-outside-toplevel

        _log.debug("%s Starting heartbeat loop...", self.shard_key)

        # When waiting for first heartbeat, there hasn't been an ack received yet.
        # Set to true so the ack received check doesn't incorrectly fail.
        self.__has_received_ack = True

        for jitter in chain((random(),), repeat(1)):
            duration = self.__heartbeat_interval * jitter

            _log.debug(
                "%s %s sending heartbeat in %sms",
                self.shard_key, datetime.now(),
                duration
            )

            # Task is needed so waiting can be cancelled by op code 1
            self.__wait_for_heartbeat = create_task(
                sleep(duration / 1000)
            )

            try:
                # `wait` returns normally whether the sleep finishes or is
                # cancelled by `send_next_heartbeat`. A bare `await` would
                # re-raise that cancellation here and end the heartbeat loop.
                await wait((self.__wait_for_heartbeat,))
            except CancelledError:
                # This heartbeat task itself was cancelled: stop the sleep too.
                self.__wait_for_heartbeat.cancel()
                raise

            if not self.__has_received_ack:
                # Close code is specified to be anything that is not 1000 in the docs.
                _log.debug(
                    "%s %s ack not received. Attempting to reconnect. "
                    "Closing socket with close code 1001. \U0001f480",
                    self.shard_key,
                    datetime.now()
                )
                await self.__close_for_reconnect(1001)
                self.__should_reconnect = True
                # A new loop is started in the background while this one is stopped.
                ensure_future(self.start_loop())
                self.stop_heartbeat()
                return

            self.__has_received_ack = False
            await self.send(str(GatewayDispatch(1, data=self.__sequence_number)))
            _log.debug("%s sent heartbeat", self.shard_key)
448